diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml
index cd12393cde..d1a372f059 100644
--- a/.github/workflows/actions.yml
+++ b/.github/workflows/actions.yml
@@ -76,7 +76,7 @@ jobs:
cache-dependency-path: go.sum
- name: Run Go Test
- run: go test -coverprofile coverage.out -timeout 20m `go list ./... | grep -v e2e`
+ run: go test -coverprofile coverage.out -timeout 25m `go list ./... | grep -v e2e`
- name: Upload coverage file to Codecov
uses: codecov/codecov-action@v3
diff --git a/.golangci.yml b/.golangci.yml
index 17b704c548..f2e11b8a47 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -51,9 +51,15 @@ issues:
exclude-rules:
- path: _test\.go
linters:
+ - dupl
+ - errorlint
+ - forcetypeassert
- gosec
- - unparam
- lll
+ - nlreturn
+ - prealloc
+ - unparam
+ - wsl
- linters:
- staticcheck
path: "state/runtime/precompiled/base.go"
diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go
index 08eb94d6a9..c9bf3ed498 100644
--- a/blockchain/blockchain.go
+++ b/blockchain/blockchain.go
@@ -15,6 +15,7 @@ import (
"github.com/dogechain-lab/dogechain/contracts/upgrader"
"github.com/dogechain-lab/dogechain/contracts/validatorset"
"github.com/dogechain-lab/dogechain/helper/common"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/types"
"github.com/dogechain-lab/dogechain/types/buildroot"
@@ -118,7 +119,7 @@ func (b *Blockchain) updateGasPriceAvg(newValues []*big.Int) {
defer b.gpAverage.Unlock()
// Sum the values for quick reference
- sum := big.NewInt(0)
+ sum := new(big.Int)
for _, val := range newValues {
sum = sum.Add(sum, val)
}
@@ -194,19 +195,16 @@ func (b *Blockchain) GetAvgGasPrice() *big.Int {
func NewBlockchain(
logger hclog.Logger,
config *chain.Chain,
- storageBuilder storage.StorageBuilder,
consensus Verifier,
+ db storage.Storage,
executor Executor,
metrics *Metrics,
) (*Blockchain, error) {
- if storageBuilder == nil {
- return nil, ErrNilStorageBuilder
- }
-
b := &Blockchain{
logger: logger.Named("blockchain"),
config: config,
consensus: consensus,
+ db: db,
executor: executor,
stream: newEventStream(context.Background()),
gpAverage: &gasPriceAverage{
@@ -216,17 +214,6 @@ func NewBlockchain(
metrics: NewDummyMetrics(metrics),
}
- var (
- db storage.Storage
- err error
- )
-
- if db, err = storageBuilder.Build(); err != nil {
- return nil, err
- }
-
- b.db = db
-
if err := b.initCaches(defaultCacheSize); err != nil {
return nil, err
}
@@ -483,7 +470,7 @@ func (b *Blockchain) writeCanonicalHeader(event *Event, h *types.Header) error {
return fmt.Errorf("parent difficulty not found")
}
- newTD := big.NewInt(0).Add(parentTD, new(big.Int).SetUint64(h.Difficulty))
+ newTD := new(big.Int).Add(parentTD, new(big.Int).SetUint64(h.Difficulty))
if err := b.db.WriteCanonicalHeader(h, newTD); err != nil {
return err
}
@@ -515,7 +502,7 @@ func (b *Blockchain) advanceHead(newHeader *types.Header) (*big.Int, error) {
}
// Check if there was a parent difficulty
- parentTD := big.NewInt(0)
+ parentTD := new(big.Int)
if newHeader.ParentHash != types.StringToHash("") {
td, ok := b.readTotalDifficulty(newHeader.ParentHash)
@@ -527,7 +514,7 @@ func (b *Blockchain) advanceHead(newHeader *types.Header) (*big.Int, error) {
}
// Calculate the new total difficulty
- newTD := big.NewInt(0).Add(parentTD, big.NewInt(0).SetUint64(newHeader.Difficulty))
+ newTD := new(big.Int).Add(parentTD, new(big.Int).SetUint64(newHeader.Difficulty))
if err := b.db.WriteTotalDifficulty(newHeader.Hash, newTD); err != nil {
return nil, err
}
@@ -717,11 +704,15 @@ func (b *Blockchain) VerifyFinalizedBlock(block *types.Block) error {
return ErrNoBlockHeader
}
+ b.logger.Debug("verify finalized block header", "number", block.Number())
+
// Make sure the consensus layer verifies this block header
if err := b.consensus.VerifyHeader(block.Header); err != nil {
return fmt.Errorf("failed to verify the header: %w", err)
}
+ b.logger.Debug("verify finalized block body", "number", block.Number())
+
// Do the initial block verification
if err := b.verifyBlock(block); err != nil {
return err
@@ -1072,6 +1063,7 @@ func (b *Blockchain) extractBlockReceipts(block *types.Block) ([]*types.Receipt,
// Check the cache for the block receipts
receipts, ok := b.receiptsCache.Get(block.Header.Hash)
if !ok {
+		b.logger.Info("execute block transactions due to no receipts cache", "block", block.Number())
// No receipts found in the cache, execute the transactions from the block
// and fetch them
blockResult, err := b.executeBlockTransactions(block)
@@ -1136,7 +1128,7 @@ func (b *Blockchain) writeBody(block *types.Block) error {
// ReadTxLookup returns the block hash using the transaction hash
func (b *Blockchain) ReadTxLookup(hash types.Hash) (types.Hash, bool) {
if b.isStopped() {
- return types.ZeroHash, false
+ return types.Hash{}, false
}
v, ok := b.db.ReadTxLookup(hash)
@@ -1185,43 +1177,15 @@ func (b *Blockchain) verifyGasLimit(header, parentHeader *types.Header) error {
return nil
}
-// GetHashHelper is used by the EVM, so that the SC can get the hash of the header number
-func (b *Blockchain) GetHashHelper(header *types.Header) func(i uint64) (res types.Hash) {
- return func(i uint64) (res types.Hash) {
- num, hash := header.Number-1, header.ParentHash
-
- for {
- if num == i {
- res = hash
-
- return
- }
-
- h, ok := b.GetHeaderByHash(hash)
- if !ok {
- return
- }
-
- hash = h.ParentHash
-
- if num == 0 {
- return
- }
-
- num--
- }
- }
-}
-
// GetHashByNumber returns the block hash using the block number
func (b *Blockchain) GetHashByNumber(blockNumber uint64) types.Hash {
if b.isStopped() {
- return types.ZeroHash
+ return types.Hash{}
}
block, ok := b.GetBlockByNumber(blockNumber, false)
if !ok {
- return types.ZeroHash
+ return types.Hash{}
}
return block.Hash()
@@ -1260,9 +1224,9 @@ func (b *Blockchain) writeHeaderImpl(evnt *Event, header *types.Header) error {
// Write the difficulty
if err := b.db.WriteTotalDifficulty(
header.Hash,
- big.NewInt(0).Add(
+ new(big.Int).Add(
parentTD,
- big.NewInt(0).SetUint64(header.Difficulty),
+ new(big.Int).SetUint64(header.Difficulty),
),
); err != nil {
return err
@@ -1282,7 +1246,7 @@ func (b *Blockchain) writeHeaderImpl(evnt *Event, header *types.Header) error {
// Update the headers cache
b.headersCache.Add(header.Hash, header)
- incomingTD := big.NewInt(0).Add(parentTD, big.NewInt(0).SetUint64(header.Difficulty))
+ incomingTD := new(big.Int).Add(parentTD, new(big.Int).SetUint64(header.Difficulty))
if incomingTD.Cmp(currentTD) > 0 {
// new block has higher difficulty, reorg the chain
if err := b.handleReorg(evnt, currentHeader, header); err != nil {
@@ -1305,7 +1269,8 @@ func (b *Blockchain) writeHeaderImpl(evnt *Event, header *types.Header) error {
func (b *Blockchain) writeFork(header *types.Header) error {
forks, err := b.db.ReadForks()
if err != nil {
- if errors.Is(err, storage.ErrNotFound) {
+		// several storage layers define their own "not found" sentinel with the
+		// same message, so errors.Is cannot match them all; compare by message
+ if err.Error() == rawdb.ErrNotFound.Error() {
forks = []types.Hash{}
} else {
return err
@@ -1484,8 +1449,7 @@ func (b *Blockchain) Close() error {
b.wg.Wait()
- // close db at last
- return b.db.Close()
+ return nil
}
func (b *Blockchain) stop() {
diff --git a/blockchain/blockchain_test.go b/blockchain/blockchain_test.go
index 995c525c96..74e3b67787 100644
--- a/blockchain/blockchain_test.go
+++ b/blockchain/blockchain_test.go
@@ -10,9 +10,10 @@ import (
"github.com/dogechain-lab/dogechain/blockchain/storage"
"github.com/dogechain-lab/dogechain/blockchain/storage/kvstorage"
"github.com/dogechain-lab/dogechain/chain"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/types"
- "github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
)
@@ -490,7 +491,7 @@ func TestInsertHeaders(t *testing.T) {
assert.Equal(t, head.Hash, expected.Hash)
forks, err := b.GetForks()
- if err != nil && !errors.Is(err, storage.ErrNotFound) {
+ if err != nil && (err.Error() != rawdb.ErrNotFound.Error()) {
t.Fatal(err)
}
@@ -545,8 +546,7 @@ func TestForkUnknownParents(t *testing.T) {
}
func TestBlockchainWriteBody(t *testing.T) {
- storage, err := kvstorage.NewMemoryStorageBuilder(hclog.NewNullLogger()).Build()
- assert.NoError(t, err)
+ storage := kvstorage.NewKeyValueStorage(memorydb.New())
b := &Blockchain{
db: storage,
@@ -704,8 +704,8 @@ func TestBlockchain_VerifyBlockParent(t *testing.T) {
t.Parallel()
emptyHeader := &types.Header{
- Hash: types.ZeroHash,
- ParentHash: types.ZeroHash,
+ Hash: types.Hash{},
+ ParentHash: types.Hash{},
}
emptyHeader.ComputeHash()
@@ -729,7 +729,7 @@ func TestBlockchain_VerifyBlockParent(t *testing.T) {
// Create a dummy block
block := &types.Block{
Header: &types.Header{
- ParentHash: types.ZeroHash,
+ ParentHash: types.Hash{},
},
}
@@ -855,8 +855,8 @@ func TestBlockchain_VerifyBlockBody(t *testing.T) {
t.Parallel()
emptyHeader := &types.Header{
- Hash: types.ZeroHash,
- ParentHash: types.ZeroHash,
+ Hash: types.Hash{},
+ ParentHash: types.Hash{},
}
t.Run("Invalid SHA3 Uncles root", func(t *testing.T) {
@@ -869,7 +869,7 @@ func TestBlockchain_VerifyBlockBody(t *testing.T) {
block := &types.Block{
Header: &types.Header{
- Sha3Uncles: types.ZeroHash,
+ Sha3Uncles: types.Hash{},
},
}
diff --git a/blockchain/storage/errors.go b/blockchain/storage/errors.go
deleted file mode 100644
index 47c3afdb4f..0000000000
--- a/blockchain/storage/errors.go
+++ /dev/null
@@ -1,5 +0,0 @@
-package storage
-
-import "fmt"
-
-var ErrNotFound = fmt.Errorf("not found")
diff --git a/blockchain/storage/kvstorage/kvstorage.go b/blockchain/storage/kvstorage/kvstorage.go
index 6883110312..1647c7313c 100644
--- a/blockchain/storage/kvstorage/kvstorage.go
+++ b/blockchain/storage/kvstorage/kvstorage.go
@@ -1,182 +1,91 @@
-//nolint:stylecheck
package kvstorage
import (
- "encoding/binary"
"math/big"
"github.com/dogechain-lab/dogechain/blockchain/storage"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/types"
- "github.com/dogechain-lab/fastrlp"
- "github.com/hashicorp/go-hclog"
)
-// Prefixes for the key-value store
-var (
- // DIFFICULTY is the difficulty prefix
- DIFFICULTY = []byte("d")
-
- // HEADER is the header prefix
- HEADER = []byte("h")
-
- // HEAD is the chain head prefix
- HEAD = []byte("o")
-
- // FORK is the entry to store forks
- FORK = []byte("f")
-
- // CANONICAL is the prefix for the canonical chain numbers
- CANONICAL = []byte("c")
-
- // BODY is the prefix for bodies
- BODY = []byte("b")
-
- // RECEIPTS is the prefix for receipts
- RECEIPTS = []byte("r")
-
- // SNAPSHOTS is the prefix for snapshots
- SNAPSHOTS = []byte("s")
-
- // TX_LOOKUP_PREFIX is the prefix for transaction lookups
- TX_LOOKUP_PREFIX = []byte("l")
-)
-
-// Sub-prefixes
-var (
- HASH = []byte("hash")
- NUMBER = []byte("number")
- EMPTY = []byte("empty")
-)
-
-// KV is a generic key-value store, need close it
-type KV interface {
- Close() error
-
- Set(p []byte, v []byte) error
- Get(p []byte) ([]byte, bool, error)
-}
-
// KeyValueStorage is a generic storage for kv databases
type KeyValueStorage struct {
- logger hclog.Logger
- db KV
-}
-
-func newKeyValueStorage(logger hclog.Logger, db KV) storage.Storage {
- return &KeyValueStorage{logger: logger, db: db}
-}
-
-func (s *KeyValueStorage) encodeUint(n uint64) []byte {
- b := make([]byte, 8)
- binary.BigEndian.PutUint64(b[:], n)
-
- return b[:]
+ db kvdb.KVBatchStorage
}
-func (s *KeyValueStorage) decodeUint(b []byte) uint64 {
- return binary.BigEndian.Uint64(b[:])
+func NewKeyValueStorage(db kvdb.KVBatchStorage) storage.Storage {
+ return &KeyValueStorage{db: db}
}
// -- canonical hash --
// ReadCanonicalHash gets the hash from the number of the canonical chain
func (s *KeyValueStorage) ReadCanonicalHash(n uint64) (types.Hash, bool) {
- data, ok := s.get(CANONICAL, s.encodeUint(n))
- if !ok {
- return types.Hash{}, false
- }
-
- return types.BytesToHash(data), true
+ return rawdb.ReadCanonicalHash(s.db, n)
}
// WriteCanonicalHash writes a hash for a number block in the canonical chain
func (s *KeyValueStorage) WriteCanonicalHash(n uint64, hash types.Hash) error {
- return s.set(CANONICAL, s.encodeUint(n), hash.Bytes())
+ return rawdb.WriteCanonicalHash(s.db, n, hash)
}
// HEAD //
// ReadHeadHash returns the hash of the head
func (s *KeyValueStorage) ReadHeadHash() (types.Hash, bool) {
- data, ok := s.get(HEAD, HASH)
- if !ok {
- return types.Hash{}, false
- }
-
- return types.BytesToHash(data), true
+ return rawdb.ReadHeadHash(s.db)
}
// ReadHeadNumber returns the number of the head
func (s *KeyValueStorage) ReadHeadNumber() (uint64, bool) {
- data, ok := s.get(HEAD, NUMBER)
- if !ok {
- return 0, false
- }
-
- if len(data) != 8 {
- return 0, false
- }
-
- return s.decodeUint(data), true
+ return rawdb.ReadHeadNumber(s.db)
}
// WriteHeadHash writes the hash of the head
func (s *KeyValueStorage) WriteHeadHash(h types.Hash) error {
- return s.set(HEAD, HASH, h.Bytes())
+ return rawdb.WriteHeadHash(s.db, h)
}
// WriteHeadNumber writes the number of the head
func (s *KeyValueStorage) WriteHeadNumber(n uint64) error {
- return s.set(HEAD, NUMBER, s.encodeUint(n))
+ return rawdb.WriteHeadNumber(s.db, n)
}
// FORK //
// WriteForks writes the current forks
func (s *KeyValueStorage) WriteForks(forks []types.Hash) error {
- ff := storage.Forks(forks)
-
- return s.writeRLP(FORK, EMPTY, &ff)
+ return rawdb.WriteForks(s.db, forks)
}
// ReadForks read the current forks
func (s *KeyValueStorage) ReadForks() ([]types.Hash, error) {
- forks := &storage.Forks{}
- err := s.readRLP(FORK, EMPTY, forks)
-
- return *forks, err
+ return rawdb.ReadForks(s.db)
}
// DIFFICULTY //
// WriteTotalDifficulty writes the difficulty
func (s *KeyValueStorage) WriteTotalDifficulty(hash types.Hash, diff *big.Int) error {
- return s.set(DIFFICULTY, hash.Bytes(), diff.Bytes())
+ return rawdb.WriteTotalDifficulty(s.db, hash, diff)
}
// ReadTotalDifficulty reads the difficulty
func (s *KeyValueStorage) ReadTotalDifficulty(hash types.Hash) (*big.Int, bool) {
- v, ok := s.get(DIFFICULTY, hash.Bytes())
- if !ok {
- return nil, false
- }
-
- return big.NewInt(0).SetBytes(v), true
+ return rawdb.ReadTotalDifficulty(s.db, hash)
}
// HEADER //
// WriteHeader writes the header
func (s *KeyValueStorage) WriteHeader(h *types.Header) error {
- return s.writeRLP(HEADER, h.Hash.Bytes(), h)
+ return rawdb.WriteHeader(s.db, h.Hash, h)
}
// ReadHeader reads the header
func (s *KeyValueStorage) ReadHeader(hash types.Hash) (*types.Header, error) {
- header := &types.Header{}
- err := s.readRLP(HEADER, hash.Bytes(), header)
-
- return header, err
+ return rawdb.ReadHeader(s.db, hash)
}
// WriteCanonicalHeader implements the storage interface
@@ -208,141 +117,34 @@ func (s *KeyValueStorage) WriteCanonicalHeader(h *types.Header, diff *big.Int) e
// WriteBody writes the body
func (s *KeyValueStorage) WriteBody(hash types.Hash, body *types.Body) error {
- return s.writeRLP(BODY, hash.Bytes(), body)
+ return rawdb.WriteBody(s.db, hash, body)
}
// ReadBody reads the body
func (s *KeyValueStorage) ReadBody(hash types.Hash) (*types.Body, error) {
- body := &types.Body{}
- err := s.readRLP(BODY, hash.Bytes(), body)
-
- return body, err
+ return rawdb.ReadBody(s.db, hash)
}
// RECEIPTS //
// WriteReceipts writes the receipts
func (s *KeyValueStorage) WriteReceipts(hash types.Hash, receipts []*types.Receipt) error {
- rr := types.Receipts(receipts)
-
- return s.writeRLP(RECEIPTS, hash.Bytes(), &rr)
+ return rawdb.WriteReceipts(s.db, hash, receipts)
}
// ReadReceipts reads the receipts
func (s *KeyValueStorage) ReadReceipts(hash types.Hash) ([]*types.Receipt, error) {
- receipts := &types.Receipts{}
- err := s.readRLP(RECEIPTS, hash.Bytes(), receipts)
-
- return *receipts, err
+ return rawdb.ReadReceipts(s.db, hash)
}
// TX LOOKUP //
// WriteTxLookup maps the transaction hash to the block hash
func (s *KeyValueStorage) WriteTxLookup(hash types.Hash, blockHash types.Hash) error {
- ar := &fastrlp.Arena{}
- vr := ar.NewBytes(blockHash.Bytes())
-
- return s.write2(TX_LOOKUP_PREFIX, hash.Bytes(), vr)
+ return rawdb.WriteTxLookup(s.db, hash, blockHash)
}
// ReadTxLookup reads the block hash using the transaction hash
func (s *KeyValueStorage) ReadTxLookup(hash types.Hash) (types.Hash, bool) {
- parser := &fastrlp.Parser{}
-
- v := s.read2(TX_LOOKUP_PREFIX, hash.Bytes(), parser)
- if v == nil {
- return types.Hash{}, false
- }
-
- blockHash := []byte{}
- blockHash, err := v.GetBytes(blockHash[:0], 32)
-
- if err != nil {
- panic(err)
- }
-
- return types.BytesToHash(blockHash), true
-}
-
-// WRITE OPERATIONS //
-
-func (s *KeyValueStorage) writeRLP(p, k []byte, raw types.RLPMarshaler) error {
- var data []byte
- if obj, ok := raw.(types.RLPStoreMarshaler); ok {
- data = obj.MarshalStoreRLPTo(nil)
- } else {
- data = raw.MarshalRLPTo(nil)
- }
-
- return s.set(p, k, data)
-}
-
-func (s *KeyValueStorage) readRLP(p, k []byte, raw types.RLPUnmarshaler) error {
- p = append(p, k...)
- data, ok, err := s.db.Get(p)
-
- if err != nil {
- return err
- }
-
- if !ok {
- return storage.ErrNotFound
- }
-
- if obj, ok := raw.(types.RLPStoreUnmarshaler); ok {
- // decode in the store format
- if err := obj.UnmarshalStoreRLP(data); err != nil {
- return err
- }
- } else {
- // normal rlp decoding
- if err := raw.UnmarshalRLP(data); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (s *KeyValueStorage) read2(p, k []byte, parser *fastrlp.Parser) *fastrlp.Value {
- data, ok := s.get(p, k)
- if !ok {
- return nil
- }
-
- v, err := parser.Parse(data)
- if err != nil {
- return nil
- }
-
- return v
-}
-
-func (s *KeyValueStorage) write2(p, k []byte, v *fastrlp.Value) error {
- dst := v.MarshalTo(nil)
-
- return s.set(p, k, dst)
-}
-
-func (s *KeyValueStorage) set(p []byte, k []byte, v []byte) error {
- p = append(p, k...)
-
- return s.db.Set(p, v)
-}
-
-func (s *KeyValueStorage) get(p []byte, k []byte) ([]byte, bool) {
- p = append(p, k...)
- data, ok, err := s.db.Get(p)
-
- if err != nil {
- return nil, false
- }
-
- return data, ok
-}
-
-// Close closes the connection with the db
-func (s *KeyValueStorage) Close() error {
- return s.db.Close()
+ return rawdb.ReadTxLookup(s.db, hash)
}
diff --git a/blockchain/storage/kvstorage/kvstorage_test.go b/blockchain/storage/kvstorage/kvstorage_test.go
new file mode 100644
index 0000000000..dd0629784a
--- /dev/null
+++ b/blockchain/storage/kvstorage/kvstorage_test.go
@@ -0,0 +1,24 @@
+package kvstorage
+
+import (
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/blockchain/storage"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+)
+
+// newStorage builds a memorydb-backed blockchain storage for the test,
+// closing the underlying database when the test finishes.
+func newStorage(t *testing.T) storage.Storage {
+	t.Helper()
+
+	db := memorydb.New()
+
+	t.Cleanup(func() {
+		_ = db.Close() // close error is not actionable during test cleanup
+	})
+
+	return NewKeyValueStorage(db)
+}
+
+func TestKeyValueStorage(t *testing.T) {
+	storage.TestStorage(t, newStorage)
+}
diff --git a/blockchain/storage/kvstorage/leveldb.go b/blockchain/storage/kvstorage/leveldb.go
deleted file mode 100644
index f8f45c3d26..0000000000
--- a/blockchain/storage/kvstorage/leveldb.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package kvstorage
-
-import (
- "github.com/dogechain-lab/dogechain/blockchain/storage"
- "github.com/dogechain-lab/dogechain/helper/kvdb"
- "github.com/hashicorp/go-hclog"
-)
-
-type leveldbStorageBuilder struct {
- logger hclog.Logger
- leveldbBuilder kvdb.LevelDBBuilder
-}
-
-func (builder *leveldbStorageBuilder) Build() (storage.Storage, error) {
- db, err := builder.leveldbBuilder.Build()
- if err != nil {
- return nil, err
- }
-
- return newKeyValueStorage(builder.logger.Named("leveldb"), db), nil
-}
-
-// NewLevelDBStorageBuilder creates the new blockchain storage builder
-func NewLevelDBStorageBuilder(logger hclog.Logger, leveldbBuilder kvdb.LevelDBBuilder) storage.StorageBuilder {
- return &leveldbStorageBuilder{
- logger: logger,
- leveldbBuilder: leveldbBuilder,
- }
-}
diff --git a/blockchain/storage/kvstorage/leveldb_test.go b/blockchain/storage/kvstorage/leveldb_test.go
deleted file mode 100644
index 2b2738637e..0000000000
--- a/blockchain/storage/kvstorage/leveldb_test.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package kvstorage
-
-import (
- "os"
- "testing"
-
- "github.com/dogechain-lab/dogechain/blockchain/storage"
- "github.com/dogechain-lab/dogechain/helper/kvdb"
- "github.com/hashicorp/go-hclog"
-)
-
-func newLevelDBStorage(t *testing.T) (storage.Storage, func()) {
- t.Helper()
-
- path, err := os.MkdirTemp("/tmp", "minimal_storage")
- if err != nil {
- t.Fatal(err)
- }
-
- logger := hclog.NewNullLogger()
-
- s, err := NewLevelDBStorageBuilder(
- logger, kvdb.NewLevelDBBuilder(logger, path)).Build()
- if err != nil {
- t.Fatal(err)
- }
-
- closeFn := func() {
- if err := s.Close(); err != nil {
- t.Fatal(err)
- }
-
- if err := os.RemoveAll(path); err != nil {
- t.Fatal(err)
- }
- }
-
- return s, closeFn
-}
-
-func TestLevelDBStorage(t *testing.T) {
- storage.TestStorage(t, newLevelDBStorage)
-}
diff --git a/blockchain/storage/kvstorage/memory.go b/blockchain/storage/kvstorage/memory.go
deleted file mode 100644
index 7177625fc6..0000000000
--- a/blockchain/storage/kvstorage/memory.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package kvstorage
-
-import (
- "github.com/dogechain-lab/dogechain/blockchain/storage"
- "github.com/dogechain-lab/dogechain/helper/hex"
- "github.com/hashicorp/go-hclog"
-)
-
-type memoryStorageBuilder struct {
- logger hclog.Logger
-}
-
-func (builder *memoryStorageBuilder) Build() (storage.Storage, error) {
- db := &memoryKV{map[string][]byte{}}
-
- return newKeyValueStorage(builder.logger, db), nil
-}
-
-// NewMemoryStorageBuilder creates the new blockchain storage builder
-func NewMemoryStorageBuilder(logger hclog.Logger) storage.StorageBuilder {
- return &memoryStorageBuilder{
- logger: logger,
- }
-}
-
-// memoryKV is an in memory implementation of the kv storage
-type memoryKV struct {
- db map[string][]byte
-}
-
-func (m *memoryKV) Set(p []byte, v []byte) error {
- m.db[hex.EncodeToHex(p)] = v
-
- return nil
-}
-
-func (m *memoryKV) Get(p []byte) ([]byte, bool, error) {
- v, ok := m.db[hex.EncodeToHex(p)]
- if !ok {
- return nil, false, nil
- }
-
- return v, true, nil
-}
-
-func (m *memoryKV) Close() error {
- return nil
-}
diff --git a/blockchain/storage/kvstorage/memory_test.go b/blockchain/storage/kvstorage/memory_test.go
deleted file mode 100644
index ec28115f0f..0000000000
--- a/blockchain/storage/kvstorage/memory_test.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package kvstorage
-
-import (
- "testing"
-
- "github.com/dogechain-lab/dogechain/blockchain/storage"
- "github.com/hashicorp/go-hclog"
-)
-
-func TestMemoryStorage(t *testing.T) {
- t.Helper()
-
- f := func(t *testing.T) (storage.Storage, func()) {
- t.Helper()
-
- s, _ := NewMemoryStorageBuilder(hclog.NewNullLogger()).Build()
-
- return s, func() {}
- }
- storage.TestStorage(t, f)
-}
diff --git a/blockchain/storage/storage.go b/blockchain/storage/storage.go
index 0af8bc4655..6607779e63 100644
--- a/blockchain/storage/storage.go
+++ b/blockchain/storage/storage.go
@@ -7,10 +7,6 @@ import (
"github.com/hashicorp/go-hclog"
)
-type StorageBuilder interface {
- Build() (Storage, error)
-}
-
// Storage is a generic blockchain storage
type Storage interface {
ReadCanonicalHash(n uint64) (types.Hash, bool)
@@ -40,8 +36,6 @@ type Storage interface {
WriteTxLookup(hash types.Hash, blockHash types.Hash) error
ReadTxLookup(hash types.Hash) (types.Hash, bool)
-
- Close() error
}
// Factory is a factory method to create a blockchain storage
diff --git a/blockchain/storage/testing.go b/blockchain/storage/testing.go
index 56fb7be68e..b4c439f2ef 100644
--- a/blockchain/storage/testing.go
+++ b/blockchain/storage/testing.go
@@ -11,7 +11,7 @@ import (
"github.com/stretchr/testify/assert"
)
-type PlaceholderStorage func(t *testing.T) (Storage, func())
+type PlaceholderStorage func(t *testing.T) Storage
var (
addr1 = types.StringToAddress("1")
@@ -54,8 +54,7 @@ func TestStorage(t *testing.T, m PlaceholderStorage) {
func testCanonicalChain(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
var cases = []struct {
Number uint64
@@ -107,8 +106,7 @@ func testCanonicalChain(t *testing.T, m PlaceholderStorage) {
func testDifficulty(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
var cases = []struct {
Diff *big.Int
@@ -154,8 +152,7 @@ func testDifficulty(t *testing.T, m PlaceholderStorage) {
func testHead(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
for i := uint64(0); i < 5; i++ {
h := &types.Header{
@@ -199,8 +196,7 @@ func testHead(t *testing.T, m PlaceholderStorage) {
func testForks(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
var cases = []struct {
Forks []types.Hash
@@ -226,8 +222,7 @@ func testForks(t *testing.T, m PlaceholderStorage) {
func testHeader(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
extraData, _ := hex.DecodeHex("0x11bbe8db4e347b4e8c937c1c8370e4b5ed33adb3db69cbdb7a38e1e50b1b82fa")
header := &types.Header{
@@ -255,8 +250,7 @@ func testHeader(t *testing.T, m PlaceholderStorage) {
func testBody(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
header := &types.Header{
Number: 5,
@@ -322,8 +316,7 @@ func testBody(t *testing.T, m PlaceholderStorage) {
func testReceipts(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
h := &types.Header{
Difficulty: 133,
@@ -398,8 +391,7 @@ func testReceipts(t *testing.T, m PlaceholderStorage) {
func testWriteCanonicalHeader(t *testing.T, m PlaceholderStorage) {
t.Helper()
- s, closeFn := m(t)
- defer closeFn()
+ s := m(t)
h := &types.Header{
Number: 100,
diff --git a/blockchain/testing.go b/blockchain/testing.go
index 3bd281bf66..b056da9c45 100644
--- a/blockchain/testing.go
+++ b/blockchain/testing.go
@@ -10,6 +10,7 @@ import (
"github.com/dogechain-lab/dogechain/blockchain/storage"
"github.com/dogechain-lab/dogechain/blockchain/storage/kvstorage"
"github.com/dogechain-lab/dogechain/chain"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
"github.com/dogechain-lab/dogechain/state"
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
"github.com/dogechain-lab/dogechain/types"
@@ -106,8 +107,21 @@ func NewTestBlockchain(t *testing.T, headers []*types.Header) *Blockchain {
},
}
- st := itrie.NewStateDB(itrie.NewMemoryStorage(), hclog.NewNullLogger(), nil)
- b, err := newBlockChain(config, state.NewExecutor(config.Params, st, hclog.NewNullLogger()))
+ st := itrie.NewStateDB(
+ memorydb.New(),
+ hclog.NewNullLogger(),
+ nil,
+ )
+
+ // no need to set up snapshot in test
+ b, err := newBlockChain(
+ config,
+ state.NewExecutor(
+ config.Params,
+ hclog.NewNullLogger(),
+ st,
+ ),
+ )
if err != nil {
t.Fatal(err)
@@ -349,8 +363,8 @@ func newBlockChain(config *chain.Chain, executor Executor) (*Blockchain, error)
b, err := NewBlockchain(
hclog.NewNullLogger(),
config,
- kvstorage.NewMemoryStorageBuilder(hclog.NewNullLogger()),
&MockVerifier{},
+ kvstorage.NewKeyValueStorage(memorydb.New()),
executor,
NilMetrics(),
)
diff --git a/command/server/config.go b/command/server/config.go
index 1c926ba300..dbeea24148 100644
--- a/command/server/config.go
+++ b/command/server/config.go
@@ -5,6 +5,7 @@ import (
"fmt"
"io/ioutil"
"strings"
+ "time"
"github.com/dogechain-lab/dogechain/jsonrpc"
"github.com/dogechain-lab/dogechain/network"
@@ -14,29 +15,50 @@ import (
// Config defines the server configuration params
type Config struct {
- GenesisPath string `json:"chain_config"`
- SecretsConfigPath string `json:"secrets_config"`
- DataDir string `json:"data_dir"`
- BlockGasTarget string `json:"block_gas_target"`
- GRPCAddr string `json:"grpc_addr"`
- JSONRPCAddr string `json:"jsonrpc_addr"`
- Telemetry *Telemetry `json:"telemetry"`
- Network *Network `json:"network"`
- ShouldSeal bool `json:"seal"`
- TxPool *TxPool `json:"tx_pool"`
- LogLevel string `json:"log_level"`
- RestoreFile string `json:"restore_file"`
- BlockTime uint64 `json:"block_time_s"`
- Headers *Headers `json:"headers"`
- LogFilePath string `json:"log_to"`
- EnableGraphQL bool `json:"enable_graphql"`
- GraphQLAddr string `json:"graphql_addr"`
- JSONRPCBatchRequestLimit uint64 `json:"json_rpc_batch_request_limit" yaml:"json_rpc_batch_request_limit"`
- JSONRPCBlockRangeLimit uint64 `json:"json_rpc_block_range_limit" yaml:"json_rpc_block_range_limit"`
- JSONNamespace string `json:"json_namespace" yaml:"json_namespace"`
- EnableWS bool `json:"enable_ws" yaml:"enable_ws"`
- EnablePprof bool `json:"enable_pprof" yaml:"enable_pprof"`
- BlockBroadcast bool `json:"enable_block_broadcast" yaml:"enable_block_broadcast"`
+ GenesisPath string `json:"chain_config"`
+ SecretsConfigPath string `json:"secrets_config"`
+ DataDir string `json:"data_dir"`
+ BlockGasTarget string `json:"block_gas_target"`
+ GRPCAddr string `json:"grpc_addr"`
+ JSONRPCAddr string `json:"jsonrpc_addr"`
+ Telemetry *Telemetry `json:"telemetry"`
+ Network *Network `json:"network"`
+ ShouldSeal bool `json:"seal"`
+ TxPool *TxPool `json:"tx_pool"`
+ LogLevel string `json:"log_level"`
+ RestoreFile string `json:"restore_file"`
+ BlockTime uint64 `json:"block_time_s"`
+ Headers *Headers `json:"headers"`
+ LogFilePath string `json:"log_to"`
+ EnableGraphQL bool `json:"enable_graphql"`
+ GraphQLAddr string `json:"graphql_addr"`
+ JSONRPCBatchRequestLimit uint64 `json:"json_rpc_batch_request_limit" yaml:"json_rpc_batch_request_limit"`
+ JSONRPCBlockRangeLimit uint64 `json:"json_rpc_block_range_limit" yaml:"json_rpc_block_range_limit"`
+ JSONNamespace string `json:"json_namespace" yaml:"json_namespace"`
+ EnableWS bool `json:"enable_ws" yaml:"enable_ws"`
+ EnablePprof bool `json:"enable_pprof" yaml:"enable_pprof"`
+ BlockBroadcast bool `json:"enable_block_broadcast" yaml:"enable_block_broadcast"`
+ EnableSnapshot bool `json:"enable_snapshot" yaml:"enable_snapshot"`
+ SnapshotAsyncBuild bool `json:"snapshot_async_build" yaml:"snapshot_async_build"`
+ CacheConfig *CacheConfig `json:"cache_config" yaml:"cache_config"`
+}
+
+type CacheConfig struct {
+ Cache int `json:"cache" yaml:"cache"`
+ SnapshotPercentage int `json:"snapshot_percentage" yaml:"snapshot_percentage"`
+ TrieCleanPercentage int `json:"trie_clean_percentage" yaml:"trie_clean_percentage"`
+ // TrieDirtyPercentage int `json:"trie_dirty_percentage" yaml:"trie_dirty_percentage"`
+ TrieCleanRejournalRaw string `json:"trie_clean_rejournal_raw" yaml:"trie_clean_rejournal_raw"`
+ // TrieTimeoutRaw string `json:"trie_timeout_raw" yaml:"trie_timeout_raw"`
+
+ SnapshotCache int `json:"snapshot_cache" yaml:"snapshot_cache"`
+ TrieCleanCache int `json:"trie_clean_cache" yaml:"trie_clean_cache"`
+ // Disk journal directory for trie cache to survive node restarts
+ TrieCleanCacheJournal string `json:"trie_clean_cache_journal" yaml:"trie_clean_cache_journal"`
+ // Time interval to regenerate the journal for clean cache
+ TrieCleanCacheRejournal time.Duration `json:"trie_clean_cache_rejournal" yaml:"trie_clean_cache_rejournal"`
+ TrieDirtyCache int `json:"trie_dirty_cache" yaml:"trie_dirty_cache"`
+ TrieTimeout time.Duration `json:"trie_timeout" yaml:"trie_timeout"`
}
// Telemetry holds the config details for metric services.
diff --git a/command/server/init.go b/command/server/init.go
index 70ea3cf504..84fbd9c4a5 100644
--- a/command/server/init.go
+++ b/command/server/init.go
@@ -5,6 +5,7 @@ import (
"fmt"
"math"
"net"
+ "time"
"github.com/dogechain-lab/dogechain/network/common"
@@ -19,6 +20,8 @@ import (
var (
errInvalidBlockTime = errors.New("invalid block time specified")
errDataDirectoryUndefined = errors.New("data directory not defined")
+ errInvalidCacheSize = errors.New("invalid cache size")
+ errInvalidPercentage = errors.New("invalid cache percentage")
)
func (p *serverParams) initConfigFromFile() error {
@@ -59,9 +62,51 @@ func (p *serverParams) initRawParams() error {
p.initPeerLimits()
p.initLogFileLocation()
+ if err := p.initCacheConfigs(); err != nil {
+ return err
+ }
+
return p.initAddresses()
}
+func (p *serverParams) initCacheConfigs() error {
+ if p.rawConfig.CacheConfig.Cache <= 0 {
+ return errInvalidCacheSize
+ }
+
+ // snapshot cache
+ if p.rawConfig.CacheConfig.SnapshotPercentage < 0 {
+ return errInvalidPercentage
+ } else {
+ p.rawConfig.CacheConfig.SnapshotCache = p.rawConfig.CacheConfig.Cache *
+ p.rawConfig.CacheConfig.SnapshotPercentage / 100
+ }
+
+ // trie clean cache
+ if p.rawConfig.CacheConfig.TrieCleanPercentage < 0 {
+ return errInvalidPercentage
+ } else {
+ p.rawConfig.CacheConfig.TrieCleanCache = p.rawConfig.CacheConfig.Cache *
+ p.rawConfig.CacheConfig.TrieCleanPercentage / 100
+ }
+
+ // trie dirty cache
+ p.rawConfig.CacheConfig.TrieDirtyCache = p.rawConfig.CacheConfig.Cache -
+ p.rawConfig.CacheConfig.SnapshotCache - p.rawConfig.CacheConfig.TrieCleanCache
+
+ // clean cache rejournal duration
+ d, err := time.ParseDuration(p.rawConfig.CacheConfig.TrieCleanRejournalRaw)
+ if err != nil {
+ return err
+ } else {
+ p.rawConfig.CacheConfig.TrieCleanCacheRejournal = d
+ }
+
+ p.rawConfig.CacheConfig.TrieTimeout = 5 * time.Minute
+
+ return nil
+}
+
func (p *serverParams) initBlockTime() error {
if p.rawConfig.BlockTime < 1 {
return errInvalidBlockTime
diff --git a/command/server/params.go b/command/server/params.go
index 4e64d7de90..f4e52d4f78 100644
--- a/command/server/params.go
+++ b/command/server/params.go
@@ -4,6 +4,7 @@ import (
"errors"
"log"
"net"
+ "path"
"strings"
"github.com/hashicorp/go-hclog"
@@ -53,6 +54,13 @@ const (
jsonrpcNamespaceFlag = "json-rpc-namespace"
enableWSFlag = "enable-ws"
blockBroadcastFlag = "block-broadcast"
+ enableSnapshotFlag = "enable-snapshot"
+ snapshotAsyncBuildFlag = "snapshot.async-build"
+ cacheFlag = "cache"
+ cacheSnapshotFlag = "cache.snapshot"
+ cacheTrieCleanFlag = "cache.trie"
+ cacheTrieCleanJournalFlag = "cache.trie.journal"
+ cacheTrieCleanRejournalFlag = "cache.trie.rejournal"
)
const (
@@ -62,9 +70,10 @@ const (
var (
params = &serverParams{
rawConfig: &Config{
- Telemetry: &Telemetry{},
- Network: &Network{},
- TxPool: &TxPool{},
+ Telemetry: &Telemetry{},
+ Network: &Network{},
+ TxPool: &TxPool{},
+ CacheConfig: &CacheConfig{},
},
}
)
@@ -203,6 +212,16 @@ func (p *serverParams) generateConfig() *server.Config {
ingoreCIDRs = append(ingoreCIDRs, ipnet)
}
+ if !p.rawConfig.EnableSnapshot {
+ // no sync mode yet, simply disable snapshot cache
+ cfg := p.rawConfig.CacheConfig
+ cfg.TrieCleanCache += cfg.SnapshotCache
+ cfg.SnapshotCache = 0 // Disable
+ }
+
+ // trie journal dir
+ trieJournalDir := path.Join(p.rawConfig.DataDir, p.rawConfig.CacheConfig.TrieCleanCacheJournal)
+
return &server.Config{
Chain: chainCfg,
JSONRPC: &server.JSONRPC{
@@ -262,5 +281,15 @@ func (p *serverParams) generateConfig() *server.Config {
Daemon: p.isDaemon,
ValidatorKey: p.validatorKey,
BlockBroadcast: p.rawConfig.BlockBroadcast,
+ EnableSnapshot: p.rawConfig.EnableSnapshot,
+ CacheConfig: &server.CacheConfig{
+ TrieCleanLimit: p.rawConfig.CacheConfig.TrieCleanCache,
+ TrieCleanJournal: trieJournalDir,
+ TrieCleanRejournal: p.rawConfig.CacheConfig.TrieCleanCacheRejournal,
+ TrieDirtyLimit: p.rawConfig.CacheConfig.TrieDirtyCache,
+ TrieTimeLimit: p.rawConfig.CacheConfig.TrieTimeout,
+ SnapshotLimit: p.rawConfig.CacheConfig.SnapshotCache,
+ SnapshotWait: !p.rawConfig.SnapshotAsyncBuild, // the opposite flag
+ },
}
}
diff --git a/command/server/server.go b/command/server/server.go
index b4387c80ea..5f9a639dc4 100644
--- a/command/server/server.go
+++ b/command/server/server.go
@@ -13,7 +13,7 @@ import (
"github.com/dogechain-lab/dogechain/command/helper"
"github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/helper/daemon"
- "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/leveldb"
"github.com/dogechain-lab/dogechain/network"
"github.com/dogechain-lab/dogechain/server"
"github.com/dogechain-lab/dogechain/txpool"
@@ -134,42 +134,42 @@ func setFlags(cmd *cobra.Command) {
cmd.Flags().IntVar(
¶ms.leveldbCacheSize,
leveldbCacheFlag,
- kvdb.DefaultLevelDBCache,
+ leveldb.DefaultCache,
"the size of the leveldb cache in MB",
)
cmd.Flags().IntVar(
¶ms.leveldbHandles,
leveldbHandlesFlag,
- kvdb.DefaultLevelDBHandles,
+ leveldb.DefaultHandles,
"the number of handles to leveldb open files",
)
cmd.Flags().IntVar(
¶ms.leveldbBloomKeyBits,
leveldbBloomKeyBitsFlag,
- kvdb.DefaultLevelDBBloomKeyBits,
+ leveldb.DefaultBloomKeyBits,
"the bits of leveldb bloom filters",
)
cmd.Flags().IntVar(
¶ms.leveldbTableSize,
leveldbTableSizeFlag,
- kvdb.DefaultLevelDBCompactionTableSize,
+ leveldb.DefaultCompactionTableSize,
"the leveldb 'sorted table' size in MB",
)
cmd.Flags().IntVar(
¶ms.leveldbTotalTableSize,
leveldbTotalTableSizeFlag,
- kvdb.DefaultLevelDBCompactionTotalSize,
+ leveldb.DefaultCompactionTotalSize,
"limits leveldb total size of 'sorted table' for each level in MB",
)
cmd.Flags().BoolVar(
¶ms.leveldbNoSync,
leveldbNoSyncFlag,
- kvdb.DefaultLevelDBNoSync,
+ leveldb.DefaultNoSyncFlag,
"leveldb nosync allows completely disable fsync",
)
}
@@ -350,6 +350,58 @@ func setFlags(cmd *cobra.Command) {
}
}
+ // cache flags
+ {
+ cmd.Flags().BoolVar(
+ ¶ms.rawConfig.EnableSnapshot,
+ enableSnapshotFlag,
+ false,
+ "(experimental) Enables snapshot-database mode",
+ )
+
+ cmd.Flags().BoolVar(
+ ¶ms.rawConfig.SnapshotAsyncBuild,
+ snapshotAsyncBuildFlag,
+ false,
+ "(experimental) Enables snapshot asynchronous (background) build to avoid halting",
+ )
+
+ cmd.Flags().IntVar(
+ ¶ms.rawConfig.CacheConfig.Cache,
+ cacheFlag,
+ 1024,
+ "Megabytes of memory allocated to internal caching",
+ )
+
+ cmd.Flags().IntVar(
+ ¶ms.rawConfig.CacheConfig.SnapshotPercentage,
+ cacheSnapshotFlag,
+ 40,
+ "Percentage of cache memory allowance to use for snapshot caching",
+ )
+
+ cmd.Flags().IntVar(
+ ¶ms.rawConfig.CacheConfig.TrieCleanPercentage,
+ cacheTrieCleanFlag,
+ 40,
+ "Percentage of cache memory allowance to use for trie caching",
+ )
+
+ cmd.Flags().StringVar(
+ ¶ms.rawConfig.CacheConfig.TrieCleanCacheJournal,
+ cacheTrieCleanJournalFlag,
+ server.TrieCacheDir,
+ "Disk journal directory for trie cache to survive node restarts",
+ )
+
+ cmd.Flags().StringVar(
+ ¶ms.rawConfig.CacheConfig.TrieCleanRejournalRaw,
+ cacheTrieCleanRejournalFlag,
+ "1h0m0s",
+ "Time interval to regenerate the trie cache journal",
+ )
+ }
+
setDevFlags(cmd)
}
diff --git a/consensus/ibft/consensus.go b/consensus/ibft/consensus.go
index 3cee4793cf..bf89895b79 100644
--- a/consensus/ibft/consensus.go
+++ b/consensus/ibft/consensus.go
@@ -12,7 +12,7 @@ import (
"github.com/dogechain-lab/dogechain/types"
)
-// runSequence starts the underlying consensus mechanism for the given height.
+// runSequence starts the underlying consensus mechanism for the given height.
// It may be called by a single thread at any given time
func (i *Ibft) runSequence(height uint64) <-chan struct{} {
done := make(chan struct{})
@@ -61,7 +61,7 @@ func (i *Ibft) runSequenceAtHeight(ctx context.Context, height uint64) {
default:
}
- if done := i.runCycle(ctx); done {
+ if isDone := i.runCycle(ctx); isDone {
return
}
}
@@ -149,15 +149,7 @@ func (i *Ibft) runAcceptState(ctx context.Context) (shouldStop bool) { // start
return
}
- // update current module cache
- if err := i.updateCurrentModules(number); err != nil {
- logger.Error(
- "failed to update submodules",
- "height", number,
- "err", err,
- )
- }
-
+ // snapshot is already updated when syncing, no need to update module here
snap, err := i.getSnapshot(parent.Number)
if err != nil {
@@ -245,7 +237,7 @@ func (i *Ibft) runAcceptState(ctx context.Context) (shouldStop bool) { // start
for i.getState() == currentstate.AcceptState {
msg, continuable := i.getNextMessage(ctx, timeout)
if !continuable {
- return true
+ return
}
if msg == nil {
@@ -328,7 +320,7 @@ func (i *Ibft) runAcceptState(ctx context.Context) (shouldStop bool) { // start
}
}
- return false
+ return
}
// runValidateState implements the Validate state loop.
@@ -379,7 +371,7 @@ func (i *Ibft) runValidateState(ctx context.Context) (shouldStop bool) {
msg, continuable := i.getNextMessage(ctx, timeout)
if !continuable {
- return true
+ return
}
if msg == nil {
diff --git a/consensus/ibft/ibft.go b/consensus/ibft/ibft.go
index 384c6deb97..a21cb431d1 100644
--- a/consensus/ibft/ibft.go
+++ b/consensus/ibft/ibft.go
@@ -239,11 +239,6 @@ func (i *Ibft) Initialize() error {
return err
}
- // // initialize fork manager
- // if err := i.forkManager.Initialize(); err != nil {
- // return err
- // }
-
// Set up the snapshots
if err := i.setupSnapshot(); err != nil {
return err
@@ -287,7 +282,7 @@ func (i *Ibft) startSyncing() {
// update module cache
if err := i.updateCurrentModules(blockNumber + 1); err != nil {
- logger.Error("failed to update sub modules", "height", blockNumber+1, "err", err)
+ logger.Warn("failed to update sub modules when syncing", "height", blockNumber+1, "err", err)
}
// reset headers of txpool
@@ -535,8 +530,8 @@ func (i *Ibft) startConsensus() {
)
if err := i.updateCurrentModules(pending); err != nil {
- i.logger.Error(
- "failed to update submodules",
+ i.logger.Warn(
+ "failed to update submodules in consensus",
"height", pending,
"err", err,
)
@@ -555,6 +550,8 @@ func (i *Ibft) startConsensus() {
i.stopSequence()
i.logger.Info("canceled sequence", "sequence", pending)
}
+
+ i.logger.Info("sequence canceled due to new block", "sequence", pending)
case <-sequenceCh:
case <-i.closeCh:
if isValidator {
@@ -575,8 +572,8 @@ func (i *Ibft) isValidSnapshot() bool {
// check if we are a validator and enabled
header := i.blockchain.Header()
- snap, err := i.getSnapshot(header.Number)
+ snap, err := i.getSnapshot(header.Number)
if err != nil {
return false
}
@@ -907,10 +904,10 @@ func (i *Ibft) updateCurrentModules(height uint64) error {
i.currentValidators = snap.Set
- i.logger.Info("update current module",
- "height", height,
- "validators", i.currentValidators,
- )
+ // i.logger.Debug("update current module",
+ // "height", height,
+ // "validators", i.currentValidators,
+ // )
return nil
}
@@ -1242,6 +1239,7 @@ var (
errIncorrectBlockHeight = errors.New("proposed block number is incorrect")
errBlockVerificationFailed = errors.New("block verification failed")
errFailedToInsertBlock = errors.New("failed to insert block")
+ errFailedToGetUpdateLock = errors.New("failed to get update lock")
)
func (i *Ibft) handleStateErr(err error) {
diff --git a/consensus/ibft/ibft_test.go b/consensus/ibft/ibft_test.go
index 7087536901..3a233689e1 100644
--- a/consensus/ibft/ibft_test.go
+++ b/consensus/ibft/ibft_test.go
@@ -105,7 +105,7 @@ func (m *MockBlockchain) SetGenesis(validators []types.Address) *types.Block {
header := &types.Header{
Number: 0,
Difficulty: 0,
- ParentHash: types.ZeroHash,
+ ParentHash: types.Hash{},
MixHash: IstanbulDigest,
Sha3Uncles: types.EmptyUncleHash,
GasLimit: defaultBlockGasLimit,
@@ -527,10 +527,10 @@ func TestTransition_AcceptState_Reject_WrongHeight_Block(t *testing.T) {
proposeBlockHeight uint64 = 3
// The latest block in the chain
- latestBlock = blockchain.MockBlock(nextSequence-1, types.ZeroHash, pool.get("B").priv, pool.ValidatorSet())
+ latestBlock = blockchain.MockBlock(nextSequence-1, types.Hash{}, pool.get("B").priv, pool.ValidatorSet())
// The next proposed block in the network
- proposedBlock = blockchain.MockBlock(proposeBlockHeight, types.ZeroHash, pool.get("C").priv, pool.ValidatorSet())
+ proposedBlock = blockchain.MockBlock(proposeBlockHeight, types.Hash{}, pool.get("C").priv, pool.ValidatorSet())
)
i.state.SetView(proto.ViewMsg(nextSequence, 0))
diff --git a/crypto/crypto.go b/crypto/crypto.go
index 88d04f4f36..9508d7c26f 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -198,23 +198,36 @@ var hasherPool = sync.Pool{
New: func() interface{} { return sha3.NewLegacyKeccak256() },
}
-// Keccak256 calculates the Keccak256
-func Keccak256(v ...[]byte) []byte {
- h, ok := hasherPool.Get().(hash.Hash)
+func NewKeccakState() KeccakState {
+ hasher, ok := hasherPool.Get().(KeccakState)
if !ok {
- h = sha3.NewLegacyKeccak256()
+ //nolint:forcetypeassert
+ hasher = sha3.NewLegacyKeccak256().(KeccakState)
}
+ return hasher
+}
+
+// Keccak256 calculates the Keccak256
+func Keccak256(v ...[]byte) []byte {
+ return Keccak256Hash(v...).Bytes()
+}
+
+func Keccak256Hash(v ...[]byte) (h types.Hash) {
+ hasher := NewKeccakState()
+
defer func() {
- h.Reset()
- hasherPool.Put(h)
+ hasher.Reset()
+ hasherPool.Put(hasher)
}()
for _, i := range v {
- h.Write(i)
+ hasher.Write(i)
}
- return h.Sum(nil)
+ hasher.Read(h[:])
+
+ return h
}
// PubKeyToAddress returns the Ethereum address of a public key
@@ -314,3 +327,11 @@ func ReadConsensusKey(manager secrets.SecretsManager) (*ecdsa.PrivateKey, error)
return BytesToPrivateKey(validatorKey)
}
+
+// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports
+// Read to get a variable amount of data from the hash state. Read is faster than Sum
+// because it doesn't copy the internal state, but also modifies the internal state.
+type KeccakState interface {
+ hash.Hash
+ Read([]byte) (int, error)
+}
diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go
index e1635d8ed1..ce68a3e3fb 100644
--- a/crypto/crypto_test.go
+++ b/crypto/crypto_test.go
@@ -244,3 +244,42 @@ func TestPrivateKeyGeneration(t *testing.T) {
assert.True(t, writtenKey.Equal(readKey))
assert.Equal(t, writtenAddress.String(), readAddress.String())
}
+
+func TestKeccak256(t *testing.T) {
+ t.Parallel()
+
+ cases := []struct {
+ input []byte
+ expect []byte
+ }{
+ {
+ input: nil,
+ expect: types.StringToHash("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(),
+ },
+ {
+ input: types.StringToBytes("0x0"),
+ expect: types.StringToHash("0xbc36789e7a1e281436464229828f817d6612f7b477d66591ff96a9e064bcc98a").Bytes(),
+ },
+ {
+ input: types.StringToBytes("0x00"),
+ expect: types.StringToHash("0xbc36789e7a1e281436464229828f817d6612f7b477d66591ff96a9e064bcc98a").Bytes(),
+ },
+ {
+ input: types.StringToAddress("0x0000000000000000000000000000000000000000").Bytes(),
+ expect: types.StringToHash("0x5380c7b7ae81a58eb98d9c78de4a1fd7fd9535fc953ed2be602daaa41767312a").Bytes(),
+ },
+ {
+ input: types.StringToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes(),
+ expect: types.StringToHash("0x290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563").Bytes(),
+ },
+ {
+ input: types.StringToHash("0x0000000000000000000000000000000000000000000000000000000000000001").Bytes(),
+ expect: types.StringToHash("0xb10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6").Bytes(),
+ },
+ }
+
+ for _, c := range cases {
+ h := Keccak256(c.input)
+ assert.Equal(t, c.expect, h)
+ }
+}
diff --git a/e2e/backup_test.go b/e2e/backup_test.go
index e13342dca4..d8ff7bf456 100644
--- a/e2e/backup_test.go
+++ b/e2e/backup_test.go
@@ -72,8 +72,8 @@ func TestBackup(t *testing.T) {
blockHash := block.Hash
for _, backupFile := range backupFiles {
- os.RemoveAll(path.Join(svr.Config.RootDir, "blockchain"))
- os.RemoveAll(path.Join(svr.Config.RootDir, "trie"))
+ os.RemoveAll(svr.BlockchainDataDir())
+ os.RemoveAll(svr.StateDataDir())
restoreSvr := framework.NewTestServer(t, svr.Config.RootDir, func(config *framework.TestServerConfig) {
*config = *svr.Config
diff --git a/e2e/framework/config.go b/e2e/framework/config.go
index 3079ef8c1a..71368c84ba 100644
--- a/e2e/framework/config.go
+++ b/e2e/framework/config.go
@@ -72,7 +72,7 @@ func (t *TestServerConfig) SetSigner(signer *crypto.EIP155Signer) {
// PrivateKey returns a private key in data directory
func (t *TestServerConfig) PrivateKey() (*ecdsa.PrivateKey, error) {
- return crypto.GenerateOrReadPrivateKey(filepath.Join(t.DataDir(), "consensus", ibft.IbftKeyName))
+ return crypto.GenerateOrReadPrivateKey(filepath.Join(t.DataDir(), _consensusDir, ibft.IbftKeyName))
}
// CALLBACKS //
diff --git a/e2e/framework/testserver.go b/e2e/framework/testserver.go
index a3c3923a04..920f102996 100644
--- a/e2e/framework/testserver.go
+++ b/e2e/framework/testserver.go
@@ -49,6 +49,13 @@ const (
binaryName = "dogechain"
)
+const (
+ _genesisFile = "genesis.json"
+ _blockchainDir = "blockchain"
+ _stateDir = "trie"
+ _consensusDir = "consensus"
+)
+
var lock sync.Mutex
var initialPort = 12000
@@ -164,6 +171,18 @@ func (t *TestServer) IBFTOperator() ibftOp.IbftOperatorClient {
return ibftOp.NewIbftOperatorClient(conn)
}
+func (t *TestServer) GenesisFile() string {
+ return filepath.Join(t.Config.RootDir, _genesisFile)
+}
+
+func (t *TestServer) BlockchainDataDir() string {
+ return filepath.Join(t.Config.RootDir, _blockchainDir)
+}
+
+func (t *TestServer) StateDataDir() string {
+ return filepath.Join(t.Config.RootDir, _stateDir)
+}
+
func (t *TestServer) ReleaseReservedPorts() {
for _, p := range t.Config.ReservedPorts {
if err := p.Close(); err != nil {
@@ -337,7 +356,7 @@ func (t *TestServer) Start(ctx context.Context) error {
args := []string{
serverCmd.Use,
// add custom chain
- "--chain", filepath.Join(t.Config.RootDir, "genesis.json"),
+ "--chain", t.GenesisFile(),
// enable grpc
"--grpc-address", t.GrpcAddr(),
// enable libp2p
@@ -437,7 +456,7 @@ func (t *TestServer) SwitchIBFTType(typ ibft.MechanismType, from uint64, to, dep
args = append(args, commandSlice...)
args = append(args,
// add custom chain
- "--chain", filepath.Join(t.Config.RootDir, "genesis.json"),
+ "--chain", t.GenesisFile(),
"--type", string(typ),
"--from", strconv.FormatUint(from, 10),
)
diff --git a/e2e/reverify_test.go b/e2e/reverify_test.go
index 5eebcd95cc..f4dbfde76e 100644
--- a/e2e/reverify_test.go
+++ b/e2e/reverify_test.go
@@ -2,22 +2,22 @@ package e2e
import (
"context"
- "path/filepath"
"testing"
"time"
"github.com/dogechain-lab/dogechain/chain"
"github.com/dogechain-lab/dogechain/e2e/framework"
- "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/leveldb"
"github.com/dogechain-lab/dogechain/reverify"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
)
func TestReverify(t *testing.T) {
- const toBlock uint64 = 10
-
- const finalToBlock uint64 = 15
+ const (
+ toBlock uint64 = 10
+ finalToBlock uint64 = 15
+ )
svrs := framework.NewTestServers(t, 4, func(config *framework.TestServerConfig) {
config.SetConsensus(framework.ConsensusDev)
@@ -25,51 +25,69 @@ func TestReverify(t *testing.T) {
config.SetDevInterval(1)
})
- {
- errs := framework.WaitForServersToSeal(svrs, toBlock)
- for _, err := range errs {
- assert.NoError(t, err)
+ errs := framework.WaitForServersToSeal(svrs, toBlock)
+ for _, err := range errs {
+ if !assert.NoError(t, err) {
+ t.FailNow()
}
}
svr := svrs[0]
+ // data dir
svrRootDir := svr.Config.RootDir
+ // block number
+ currentBlockHeight, err := svr.JSONRPC().Eth().BlockNumber()
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
+
+ // stop server to make some db corruption
svr.Stop()
- time.Sleep(time.Second * 2)
+ // wait for the process to return leveldb file lock
+ time.Sleep(3 * time.Second)
// open trie database
- leveldbBuilder := kvdb.NewLevelDBBuilder(
- hclog.NewNullLogger(),
- filepath.Join(svrRootDir, "trie"),
- )
-
- // open chain database
- trie, err := leveldbBuilder.Build()
- assert.NoError(t, err)
+ trie, err := leveldb.New(svr.StateDataDir())
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
// corrupt data
{
- iter := trie.Iterator(nil)
+ iter := trie.NewIterator(nil, nil)
assert.NoError(t, iter.Error())
- iter.Last()
- assert.NoError(t, iter.Error())
+ var count uint64 = 0
+ for iter.Next() {
+ count++
+ // do nothing to reach the end
+ if count < currentBlockHeight {
+ continue
+ }
+
+ err := trie.Set(iter.Key(), []byte("corrupted data"))
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
+ }
- for iter.Prev() {
- assert.NoError(t, iter.Error())
- trie.Set(iter.Key(), []byte("corrupted data"))
+ if !assert.NoError(t, iter.Error()) {
+ t.FailNow()
}
+
iter.Release()
- err := trie.Close()
- assert.NoError(t, err)
+ err := trie.Close()
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
}
- genesis, parseErr := chain.Import(
- filepath.Join(svrRootDir, "genesis.json"),
- )
- assert.NoError(t, parseErr)
+ genesis, parseErr := chain.Import(svr.GenesisFile())
+ if !assert.NoError(t, parseErr) {
+ t.FailNow()
+ }
err = reverify.ReverifyChain(
hclog.NewNullLogger(),
@@ -77,8 +95,12 @@ func TestReverify(t *testing.T) {
svrRootDir,
1,
)
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
- assert.NoError(t, err)
+ // wait for the process to return leveldb file lock
+ time.Sleep(3 * time.Second)
resvr := framework.NewTestServer(t, svrRootDir, func(config *framework.TestServerConfig) {
*config = *svr.Config
diff --git a/go.mod b/go.mod
index f578e40b7a..c0304d9ebb 100644
--- a/go.mod
+++ b/go.mod
@@ -32,7 +32,7 @@ require (
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7
github.com/umbracle/go-eth-bn256 v0.0.0-20190607160430-b36caf4e0f6b
github.com/umbracle/go-web3 v0.0.0-20220224145938-aaa1038c1b69
- golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e
+ golang.org/x/crypto v0.1.0
google.golang.org/grpc v1.45.0
google.golang.org/protobuf v1.28.1
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce
@@ -52,9 +52,9 @@ require (
github.com/umbracle/fastrlp v0.0.0-20220527094140-59d5dd30e722 // indirect
github.com/valyala/fastjson v1.6.3 // indirect
go.uber.org/zap v1.22.0 // indirect
- golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4
+ golang.org/x/sync v0.1.0
golang.org/x/sys v0.5.0 // indirect
- golang.org/x/tools v0.1.12 // indirect
+ golang.org/x/tools v0.1.12
google.golang.org/genproto v0.0.0-20211208223120-3a66f561d7aa // indirect
lukechampine.com/blake3 v1.1.7 // indirect
)
@@ -84,6 +84,8 @@ require (
github.com/multiformats/go-multicodec v0.5.0 // indirect
)
+require github.com/holiman/bloomfilter/v2 v2.0.3
+
require (
github.com/VictoriaMetrics/fastcache v1.6.0
github.com/armon/go-radix v1.0.0 // indirect
@@ -92,7 +94,7 @@ require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/cheekybits/genny v1.0.0 // indirect
github.com/coreos/go-systemd/v22 v22.3.2 // indirect
- github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/davecgh/go-spew v1.1.1
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/flynn/noise v1.0.0 // indirect
@@ -181,7 +183,7 @@ require (
golang.org/x/net v0.7.0 // indirect
golang.org/x/term v0.5.0 // indirect
golang.org/x/text v0.7.0 // indirect
- golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect
+ golang.org/x/time v0.1.0 // indirect
gopkg.in/square/go-jose.v2 v2.5.1 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
diff --git a/go.sum b/go.sum
index ff37ab356d..fad6daee5a 100644
--- a/go.sum
+++ b/go.sum
@@ -403,6 +403,8 @@ github.com/hashicorp/vault/sdk v0.6.0 h1:6Z+In5DXHiUfZvIZdMx7e2loL1PPyDjA4bVh9ZT
github.com/hashicorp/vault/sdk v0.6.0/go.mod h1:+DRpzoXIdMvKc88R4qxr+edwy/RvH5QK8itmxLiDHLc=
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M=
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
+github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao=
+github.com/holiman/bloomfilter/v2 v2.0.3/go.mod h1:zpoh+gs7qcpqrHr3dB55AMiJwo0iURXE7ZOP9L9hSkA=
github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef h1:A9HsByNhogrvm9cWb28sjiS3i7tcKCkflWFEkHfuAgM=
github.com/howeyc/gopass v0.0.0-20210920133722-c8aef6fb66ef/go.mod h1:lADxMC39cJJqL93Duh1xhAs4I2Zs8mKS89XWXFGp9cs=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
@@ -894,6 +896,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM=
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
+golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -984,6 +988,8 @@ golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
+golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
+golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@@ -1008,8 +1014,9 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
+golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180810173357-98c5dad5d1a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -1086,11 +1093,16 @@ golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
+golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1101,6 +1113,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1108,8 +1121,8 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs=
-golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA=
+golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/graphql/endpoints.go b/graphql/endpoints.go
index d4fb445c28..17cfe6f524 100644
--- a/graphql/endpoints.go
+++ b/graphql/endpoints.go
@@ -6,8 +6,8 @@ import (
"github.com/dogechain-lab/dogechain/blockchain"
"github.com/dogechain-lab/dogechain/chain"
"github.com/dogechain-lab/dogechain/helper/progress"
- "github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
)
@@ -23,8 +23,8 @@ type ethTxPoolStore interface {
}
type ethStateStore interface {
- GetAccount(root types.Hash, addr types.Address) (*state.Account, error)
- GetStorage(root types.Hash, addr types.Address, slot types.Hash) ([]byte, error)
+ GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error)
+ GetStorage(stateRoot types.Hash, addr types.Address, slot types.Hash) (types.Hash, error)
GetForksInTime(blockNumber uint64) chain.ForksInTime
GetCode(stateRoot types.Hash, account types.Address) ([]byte, error)
}
diff --git a/graphql/graphql.go b/graphql/graphql.go
index e76085e8a3..8faf11c20f 100644
--- a/graphql/graphql.go
+++ b/graphql/graphql.go
@@ -11,7 +11,6 @@ import (
"github.com/dogechain-lab/dogechain/graphql/argtype"
rpc "github.com/dogechain-lab/dogechain/jsonrpc"
"github.com/dogechain-lab/dogechain/types"
- "github.com/dogechain-lab/fastrlp"
)
var (
@@ -44,7 +43,7 @@ func (a *Account) getStateRoot(ctx context.Context) (types.Hash, error) {
header, err := a.getHeaderFromBlockNumberOrHash(&a.blockNrOrHash)
if err != nil {
- return types.ZeroHash, errFetchingHeader
+ return types.Hash{}, errFetchingHeader
}
return header.StateRoot, nil
@@ -165,33 +164,16 @@ func (a *Account) Code(ctx context.Context) (argtype.Bytes, error) {
func (a *Account) Storage(ctx context.Context, args struct{ Slot types.Hash }) (types.Hash, error) {
root, err := a.getStateRoot(ctx)
if err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
// Get the storage for the passed in location
result, err := a.backend.GetStorage(root, a.address, args.Slot)
- if err != nil {
- if errors.Is(err, rpc.ErrStateNotFound) {
- return types.ZeroHash, nil
- }
-
- return types.ZeroHash, err
- }
-
- // Parse the RLP value
- p := &fastrlp.Parser{}
-
- v, err := p.Parse(result)
- if err != nil {
- return types.ZeroHash, nil
- }
-
- data, err := v.Bytes()
- if err != nil {
- return types.ZeroHash, nil
+ if err != nil && errors.Is(err, rpc.ErrStateNotFound) {
+ return result, nil
}
- return types.BytesToHash(data), nil
+ return result, err
}
// Log represents an individual log message. All arguments are mandatory.
@@ -738,7 +720,7 @@ func (b *Block) Number(ctx context.Context) (argtype.Long, error) {
func (b *Block) Hash(ctx context.Context) (types.Hash, error) {
if _, err := b.resolveHeader(ctx); err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
return b.hash, nil
@@ -809,7 +791,7 @@ func (b *Block) Nonce(ctx context.Context) (argtype.Bytes, error) {
func (b *Block) MixHash(ctx context.Context) (types.Hash, error) {
if _, err := b.resolveHeader(ctx); err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
return b.header.MixHash, nil
@@ -817,7 +799,7 @@ func (b *Block) MixHash(ctx context.Context) (types.Hash, error) {
func (b *Block) TransactionsRoot(ctx context.Context) (types.Hash, error) {
if _, err := b.resolveHeader(ctx); err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
return b.header.TxRoot, nil
@@ -825,7 +807,7 @@ func (b *Block) TransactionsRoot(ctx context.Context) (types.Hash, error) {
func (b *Block) StateRoot(ctx context.Context) (types.Hash, error) {
if _, err := b.resolveHeader(ctx); err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
return b.header.StateRoot, nil
@@ -833,7 +815,7 @@ func (b *Block) StateRoot(ctx context.Context) (types.Hash, error) {
func (b *Block) ReceiptsRoot(ctx context.Context) (types.Hash, error) {
if _, err := b.resolveHeader(ctx); err != nil {
- return types.ZeroHash, err
+ return types.Hash{}, err
}
return b.header.ReceiptsRoot, nil
diff --git a/graphql/service.go b/graphql/service.go
index 6f192f3149..c24c5194a2 100644
--- a/graphql/service.go
+++ b/graphql/service.go
@@ -18,6 +18,7 @@ type GraphQLService struct {
config *Config
ui *GraphiQL
handler *handler
+ server *http.Server
}
type Config struct {
@@ -91,11 +92,13 @@ func (svc *GraphQLService) setupHTTP() error {
mux.Handle("/graphql", middlewareFactory(svc.config)(graphqlHandler))
mux.Handle("/graphql/", middlewareFactory(svc.config)(graphqlHandler))
- srv := http.Server{
+ srv := &http.Server{
Handler: mux,
ReadHeaderTimeout: time.Minute,
}
+ svc.server = srv
+
go func() {
if err := srv.Serve(lis); err != nil {
svc.logger.Error("closed http connection", "err", err)
@@ -105,6 +108,17 @@ func (svc *GraphQLService) setupHTTP() error {
return nil
}
+func (svc *GraphQLService) Close() error {
+ if svc.server == nil {
+ return nil
+ }
+
+ err := svc.server.Close()
+ svc.server = nil
+
+ return err
+}
+
type handler struct {
Schema *graphql.Schema
}
diff --git a/helper/common/common.go b/helper/common/common.go
index be5590dd99..4014274ee7 100644
--- a/helper/common/common.go
+++ b/helper/common/common.go
@@ -148,7 +148,7 @@ func createDir(path string) error {
}
if os.IsNotExist(err) {
- if err := os.MkdirAll(path, os.ModePerm); err != nil {
+ if err := os.MkdirAll(path, 0755); err != nil {
return err
}
}
diff --git a/helper/kvdb/batch.go b/helper/kvdb/batch.go
new file mode 100644
index 0000000000..b7da49f7f1
--- /dev/null
+++ b/helper/kvdb/batch.go
@@ -0,0 +1,28 @@
+package kvdb
+
+// IdealBatchSize defines the size of the data batches should ideally add in one
+// write.
+const IdealBatchSize = 100 * 1024
+
+type Batch interface {
+ KVWriter
+
+ // ValueSize retrieves the amount of data queued up for writing.
+ ValueSize() int
+
+ // Write flushes any accumulated data to disk.
+ Write() error
+
+ // Reset resets the batch for reuse.
+ Reset()
+
+ // Replay replays the batch contents.
+ Replay(w KVWriter) error
+}
+
+// Batcher wraps the NewBatch method of a backing data store.
+type Batcher interface {
+ // NewBatch creates a write-only database that buffers changes to its host db
+ // until a final write is called.
+ NewBatch() Batch
+}
diff --git a/helper/kvdb/builder.go b/helper/kvdb/builder.go
deleted file mode 100644
index c2545285c2..0000000000
--- a/helper/kvdb/builder.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package kvdb
-
-import (
- "fmt"
-
- "github.com/hashicorp/go-hclog"
- "github.com/syndtr/goleveldb/leveldb"
- "github.com/syndtr/goleveldb/leveldb/filter"
- "github.com/syndtr/goleveldb/leveldb/opt"
-)
-
-const (
- // minLevelDBCache is the minimum memory allocate to leveldb
- // half write, half read
- minLevelDBCache = 16 // 16 MiB
-
- // minLevelDBHandles is the minimum number of files handles to leveldb open files
- minLevelDBHandles = 16
-
- DefaultLevelDBCache = 1024 // 1 GiB
- DefaultLevelDBHandles = 512 // files handles to leveldb open files
- DefaultLevelDBBloomKeyBits = 2048 // bloom filter bits (256 bytes)
- DefaultLevelDBCompactionTableSize = 4 // 4 MiB
- DefaultLevelDBCompactionTotalSize = 40 // 40 MiB
- DefaultLevelDBNoSync = false
-)
-
-func max(a, b int) int {
- if a > b {
- return a
- }
-
- return b
-}
-
-type LevelDBBuilder interface {
- // set cache size
- SetCacheSize(int) LevelDBBuilder
-
- // set handles
- SetHandles(int) LevelDBBuilder
-
- // set bloom key bits
- SetBloomKeyBits(int) LevelDBBuilder
-
- // set compaction table size
- SetCompactionTableSize(int) LevelDBBuilder
-
- // set compaction table total size
- SetCompactionTotalSize(int) LevelDBBuilder
-
- // set no sync
- SetNoSync(bool) LevelDBBuilder
-
- // build the storage
- Build() (KVBatchStorage, error)
-}
-
-type leveldbBuilder struct {
- logger hclog.Logger
- path string
- options *opt.Options
-}
-
-func (builder *leveldbBuilder) SetCacheSize(cacheSize int) LevelDBBuilder {
- cacheSize = max(cacheSize, minLevelDBCache)
-
- builder.options.BlockCacheCapacity = cacheSize * opt.MiB
-
- builder.logger.Info("leveldb",
- "BlockCacheCapacity", fmt.Sprintf("%d Mib", cacheSize),
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) SetHandles(handles int) LevelDBBuilder {
- builder.options.OpenFilesCacheCapacity = max(handles, minLevelDBHandles)
-
- builder.logger.Info("leveldb",
- "OpenFilesCacheCapacity", builder.options.OpenFilesCacheCapacity,
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) SetBloomKeyBits(bloomKeyBits int) LevelDBBuilder {
- builder.options.Filter = filter.NewBloomFilter(bloomKeyBits)
-
- builder.logger.Info("leveldb",
- "BloomFilter bits", bloomKeyBits,
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) SetCompactionTableSize(compactionTableSize int) LevelDBBuilder {
- builder.options.CompactionTableSize = compactionTableSize * opt.MiB
- builder.options.WriteBuffer = builder.options.CompactionTableSize * 2
-
- builder.logger.Info("leveldb",
- "CompactionTableSize", fmt.Sprintf("%d Mib", compactionTableSize),
- "WriteBuffer", fmt.Sprintf("%d Mib", builder.options.WriteBuffer/opt.MiB),
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) SetCompactionTotalSize(compactionTotalSize int) LevelDBBuilder {
- builder.options.CompactionTotalSize = compactionTotalSize * opt.MiB
-
- builder.logger.Info("leveldb",
- "CompactionTotalSize", fmt.Sprintf("%d Mib", compactionTotalSize),
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) SetNoSync(noSync bool) LevelDBBuilder {
- builder.options.NoSync = noSync
-
- builder.logger.Info("leveldb",
- "NoSync", noSync,
- )
-
- return builder
-}
-
-func (builder *leveldbBuilder) Build() (KVBatchStorage, error) {
- db, err := leveldb.OpenFile(builder.path, builder.options)
- if err != nil {
- return nil, err
- }
-
- return &levelDBKV{db: db}, nil
-}
-
-// NewBuilder creates the new leveldb storage builder
-func NewLevelDBBuilder(logger hclog.Logger, path string) LevelDBBuilder {
- return &leveldbBuilder{
- logger: logger,
- path: path,
- options: &opt.Options{
- OpenFilesCacheCapacity: minLevelDBHandles,
- CompactionTableSize: DefaultLevelDBCompactionTableSize * opt.MiB,
- CompactionTotalSize: DefaultLevelDBCompactionTotalSize * opt.MiB,
- BlockCacheCapacity: minLevelDBCache * opt.MiB,
- WriteBuffer: (DefaultLevelDBCompactionTableSize * 2) * opt.MiB,
- CompactionTableSizeMultiplier: 1.1, // scale size up 1.1 multiple in next level
- Filter: filter.NewBloomFilter(DefaultLevelDBBloomKeyBits),
- NoSync: false,
- BlockSize: 256 * opt.KiB, // default 4kb, but one key-value pair need 0.5kb
- FilterBaseLg: 19, // 512kb
- DisableSeeksCompaction: true,
- },
- }
-}
diff --git a/helper/kvdb/database.go b/helper/kvdb/database.go
new file mode 100644
index 0000000000..fe23737cd7
--- /dev/null
+++ b/helper/kvdb/database.go
@@ -0,0 +1,48 @@
+package kvdb
+
+import "io"
+
+// KVReader wraps the Get method of a backing data store.
+type KVReader interface {
+ // Has retrieves if a key is present in the key-value data store.
+ Has(key []byte) (bool, error)
+ // Get retrieves the given key if it's present in the key-value data store.
+ Get(key []byte) (value []byte, exists bool, err error)
+}
+
+// KVWriter wraps the Put method of a backing data store.
+type KVWriter interface {
+ // Set inserts the given value into the key-value data store.
+ Set(k, v []byte) error
+ // Delete removes the key from the key-value data store.
+ Delete(key []byte) error
+}
+
+// KVBatchStorage is a batch write for leveldb
+type KVBatchStorage interface {
+ KVReader
+ KVWriter
+ Batcher
+ Iteratee
+ io.Closer
+}
+
+// Reader contains the methods required to read data from key-value
+type Reader interface {
+ KVReader
+}
+
+// Writer contains the methods required to write data to key-value
+type Writer interface {
+ KVWriter
+}
+
+// Database contains all the methods required by the high level database to not
+// only access the key-value data store but also the chain freezer.
+type Database interface {
+ Reader
+ Writer
+ Batcher
+ Iteratee
+ io.Closer
+}
diff --git a/helper/kvdb/dbtest/testsuite.go b/helper/kvdb/dbtest/testsuite.go
new file mode 100644
index 0000000000..793d7b54cc
--- /dev/null
+++ b/helper/kvdb/dbtest/testsuite.go
@@ -0,0 +1,397 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package dbtest
+
+import (
+ "bytes"
+ "reflect"
+ "sort"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+)
+
+// TestDatabaseSuite runs a suite of tests against a KVBatchStorage database
+// implementation.
+func TestDatabaseSuite(t *testing.T, New func() kvdb.KVBatchStorage) {
+ t.Run("Iterator", func(t *testing.T) {
+ tests := []struct {
+ content map[string]string
+ prefix string
+ start string
+ order []string
+ }{
+ // Empty databases should be iterable
+ {map[string]string{}, "", "", nil},
+ {map[string]string{}, "non-existent-prefix", "", nil},
+
+ // Single-item databases should be iterable
+ {map[string]string{"key": "val"}, "", "", []string{"key"}},
+ {map[string]string{"key": "val"}, "k", "", []string{"key"}},
+ {map[string]string{"key": "val"}, "l", "", nil},
+
+ // Multi-item databases should be fully iterable
+ {
+ map[string]string{"k1": "v1", "k5": "v5", "k2": "v2", "k4": "v4", "k3": "v3"},
+ "", "",
+ []string{"k1", "k2", "k3", "k4", "k5"},
+ },
+ {
+ map[string]string{"k1": "v1", "k5": "v5", "k2": "v2", "k4": "v4", "k3": "v3"},
+ "k", "",
+ []string{"k1", "k2", "k3", "k4", "k5"},
+ },
+ {
+ map[string]string{"k1": "v1", "k5": "v5", "k2": "v2", "k4": "v4", "k3": "v3"},
+ "l", "",
+ nil,
+ },
+ // Multi-item databases should be prefix-iterable
+ {
+ map[string]string{
+ "ka1": "va1", "ka5": "va5", "ka2": "va2", "ka4": "va4", "ka3": "va3",
+ "kb1": "vb1", "kb5": "vb5", "kb2": "vb2", "kb4": "vb4", "kb3": "vb3",
+ },
+ "ka", "",
+ []string{"ka1", "ka2", "ka3", "ka4", "ka5"},
+ },
+ {
+ map[string]string{
+ "ka1": "va1", "ka5": "va5", "ka2": "va2", "ka4": "va4", "ka3": "va3",
+ "kb1": "vb1", "kb5": "vb5", "kb2": "vb2", "kb4": "vb4", "kb3": "vb3",
+ },
+ "kc", "",
+ nil,
+ },
+ // Multi-item databases should be prefix-iterable with start position
+ {
+ map[string]string{
+ "ka1": "va1", "ka5": "va5", "ka2": "va2", "ka4": "va4", "ka3": "va3",
+ "kb1": "vb1", "kb5": "vb5", "kb2": "vb2", "kb4": "vb4", "kb3": "vb3",
+ },
+ "ka", "3",
+ []string{"ka3", "ka4", "ka5"},
+ },
+ {
+ map[string]string{
+ "ka1": "va1", "ka5": "va5", "ka2": "va2", "ka4": "va4", "ka3": "va3",
+ "kb1": "vb1", "kb5": "vb5", "kb2": "vb2", "kb4": "vb4", "kb3": "vb3",
+ },
+ "ka", "8",
+ nil,
+ },
+ }
+ for i, tt := range tests {
+ // Create the key-value data store
+ db := New()
+ for key, val := range tt.content {
+ if err := db.Set([]byte(key), []byte(val)); err != nil {
+ t.Fatalf("test %d: failed to insert item %s:%s into database: %v", i, key, val, err)
+ }
+ }
+ // Iterate over the database with the given configs and verify the results
+ it, idx := db.NewIterator([]byte(tt.prefix), []byte(tt.start)), 0
+ for it.Next() {
+ if len(tt.order) <= idx {
+ t.Errorf("test %d: prefix=%q more items than expected: checking idx=%d (key %q), expecting len=%d",
+ i, tt.prefix, idx, it.Key(), len(tt.order))
+
+ break
+ }
+ if !bytes.Equal(it.Key(), []byte(tt.order[idx])) {
+ t.Errorf("test %d: item %d: key mismatch: have %s, want %s",
+ i, idx, string(it.Key()), tt.order[idx])
+ }
+ if !bytes.Equal(it.Value(), []byte(tt.content[tt.order[idx]])) {
+ t.Errorf("test %d: item %d: value mismatch: have %s, want %s",
+ i, idx, string(it.Value()), tt.content[tt.order[idx]])
+ }
+ idx++
+ }
+ if err := it.Error(); err != nil {
+ t.Errorf("test %d: iteration failed: %v", i, err)
+ }
+ if idx != len(tt.order) {
+ t.Errorf("test %d: iteration terminated prematurely: have %d, want %d",
+ i, idx, len(tt.order))
+ }
+ db.Close()
+ }
+ })
+
+ t.Run("IteratorWith", func(t *testing.T) {
+ db := New()
+ defer db.Close()
+
+ keys := []string{"1", "2", "3", "4", "6", "10", "11", "12", "20", "21", "22"}
+ sort.Strings(keys) // 1, 10, 11, etc
+
+ for _, k := range keys {
+ if err := db.Set([]byte(k), nil); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ {
+ it := db.NewIterator(nil, nil)
+ got, want := iterateKeys(it), keys
+ if err := it.Error(); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Iterator: got: %s; want: %s", got, want)
+ }
+ }
+
+ {
+ it := db.NewIterator([]byte("1"), nil)
+ got, want := iterateKeys(it), []string{"1", "10", "11", "12"}
+ if err := it.Error(); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("IteratorWith(1,nil): got: %s; want: %s", got, want)
+ }
+ }
+
+ {
+ it := db.NewIterator([]byte("5"), nil)
+ got, want := iterateKeys(it), []string{}
+ if err := it.Error(); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("IteratorWith(5,nil): got: %s; want: %s", got, want)
+ }
+ }
+
+ {
+ it := db.NewIterator(nil, []byte("2"))
+ got, want := iterateKeys(it), []string{"2", "20", "21", "22", "3", "4", "6"}
+ if err := it.Error(); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("IteratorWith(nil,2): got: %s; want: %s", got, want)
+ }
+ }
+
+ {
+ it := db.NewIterator(nil, []byte("5"))
+ got, want := iterateKeys(it), []string{"6"}
+ if err := it.Error(); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("IteratorWith(nil,5): got: %s; want: %s", got, want)
+ }
+ }
+ })
+
+ t.Run("KeyValueOperations", func(t *testing.T) {
+ db := New()
+ defer db.Close()
+
+ key := []byte("foo")
+
+ // if got, err := db.Has(key); err != nil {
+ // t.Error(err)
+ // } else if got {
+ // t.Errorf("wrong value: %t", got)
+ // }
+
+ value := []byte("hello world")
+ if err := db.Set(key, value); err != nil {
+ t.Error(err)
+ }
+
+ // if got, err := db.Has(key); err != nil {
+ // t.Error(err)
+ // } else if !got {
+ // t.Errorf("wrong value: %t", got)
+ // }
+
+ if got, _, err := db.Get(key); err != nil {
+ t.Error(err)
+ } else if !bytes.Equal(got, value) {
+ t.Errorf("wrong value: %q", got)
+ }
+
+ // if err := db.Delete(key); err != nil {
+ // t.Error(err)
+ // }
+
+ // if got, err := db.Has(key); err != nil {
+ // t.Error(err)
+ // } else if got {
+ // t.Errorf("wrong value: %t", got)
+ // }
+ })
+
+ // t.Run("BatchReplay", func(t *testing.T) {
+ // db := New()
+ // defer db.Close()
+
+ // want := []string{"1", "2", "3", "4"}
+ // b := db.NewBatch()
+ // for _, k := range want {
+ // if err := b.Set([]byte(k), nil); err != nil {
+ // t.Fatal(err)
+ // }
+ // }
+
+ // b2 := db.NewBatch()
+ // if err := b.Replay(b2); err != nil {
+ // t.Fatal(err)
+ // }
+
+ // if err := b2.Replay(db); err != nil {
+ // t.Fatal(err)
+ // }
+
+ // it := db.NewIterator(nil, nil)
+ // if got := iterateKeys(it); !reflect.DeepEqual(got, want) {
+ // t.Errorf("got: %s; want: %s", got, want)
+ // }
+ // })
+
+ // t.Run("Snapshot", func(t *testing.T) {
+ // db := New()
+ // defer db.Close()
+
+ // initial := map[string]string{
+ // "k1": "v1", "k2": "v2", "k3": "", "k4": "",
+ // }
+ // for k, v := range initial {
+ // db.Set([]byte(k), []byte(v))
+ // }
+ // snapshot, err := db.NewSnapshot()
+ // if err != nil {
+ // t.Fatal(err)
+ // }
+ // for k, v := range initial {
+ // got, err := snapshot.Get([]byte(k))
+ // if err != nil {
+ // t.Fatal(err)
+ // }
+ // if !bytes.Equal(got, []byte(v)) {
+ // t.Fatalf("Unexpected value want: %v, got %v", v, got)
+ // }
+ // }
+
+ // // Flush more modifications into the database, ensure the snapshot
+ // // isn't affected.
+ // var (
+ // update = map[string]string{"k1": "v1-b", "k3": "v3-b"}
+ // insert = map[string]string{"k5": "v5-b"}
+ // delete = map[string]string{"k2": ""}
+ // )
+ // for k, v := range update {
+ // db.Set([]byte(k), []byte(v))
+ // }
+ // for k, v := range insert {
+ // db.Set([]byte(k), []byte(v))
+ // }
+ // for k := range delete {
+ // db.Delete([]byte(k))
+ // }
+ // for k, v := range initial {
+ // got, err := snapshot.Get([]byte(k))
+ // if err != nil {
+ // t.Fatal(err)
+ // }
+ // if !bytes.Equal(got, []byte(v)) {
+ // t.Fatalf("Unexpected value want: %v, got %v", v, got)
+ // }
+ // }
+ // for k := range insert {
+ // got, err := snapshot.Get([]byte(k))
+ // if err == nil || len(got) != 0 {
+ // t.Fatal("Unexpected value")
+ // }
+ // }
+ // for k := range delete {
+ // got, err := snapshot.Get([]byte(k))
+ // if err != nil || len(got) == 0 {
+ // t.Fatal("Unexpected deletion")
+ // }
+ // }
+ // })
+
+ t.Run("Batch", func(t *testing.T) {
+ db := New()
+ defer db.Close()
+
+ b := db.NewBatch()
+ for _, k := range []string{"1", "2", "3", "4"} {
+ if err := b.Set([]byte(k), nil); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // if has, err := db.Has([]byte("1")); err != nil {
+ // t.Fatal(err)
+ // } else if has {
+ // t.Error("db contains element before batch write")
+ // }
+
+ if err := b.Write(); err != nil {
+ t.Fatal(err)
+ }
+
+ {
+ it := db.NewIterator(nil, nil)
+ if got, want := iterateKeys(it), []string{"1", "2", "3", "4"}; !reflect.DeepEqual(got, want) {
+ t.Errorf("got: %s; want: %s", got, want)
+ }
+ }
+
+ // b.Reset()
+
+ // Mix writes and deletes in batch
+ b.Set([]byte("5"), nil)
+ // b.Delete([]byte("1"))
+ b.Set([]byte("6"), nil)
+ // b.Delete([]byte("3"))
+ b.Set([]byte("3"), nil)
+
+ if err := b.Write(); err != nil {
+ t.Fatal(err)
+ }
+
+ {
+ it := db.NewIterator(nil, nil)
+
+ // if got, want := iterateKeys(it), []string{"2", "3", "4", "5", "6"}; !reflect.DeepEqual(got, want) {
+ if got, want := iterateKeys(it), []string{"1", "2", "3", "4", "5", "6"}; !reflect.DeepEqual(got, want) {
+ t.Errorf("got: %s; want: %s", got, want)
+ }
+ }
+ })
+}
+
+func iterateKeys(it kvdb.Iterator) []string {
+ keys := []string{}
+ for it.Next() {
+ keys = append(keys, string(it.Key()))
+ }
+
+ sort.Strings(keys)
+ it.Release()
+
+ return keys
+}
diff --git a/helper/kvdb/iterator.go b/helper/kvdb/iterator.go
new file mode 100644
index 0000000000..7bf20a355f
--- /dev/null
+++ b/helper/kvdb/iterator.go
@@ -0,0 +1,36 @@
+package kvdb
+
+type Iterator interface {
+ // Next moves the iterator to the next key/value pair.
+ // It returns false if the iterator is exhausted.
+ Next() bool
+
+ // Key returns the key of the current key/value pair, or nil if done.
+ // The caller should not modify the contents of the returned slice, and
+ // its contents may change on the next call to any 'seeks method'.
+ Key() []byte
+
+ // Value returns the value of the current key/value pair, or nil if done.
+ // The caller should not modify the contents of the returned slice, and
+ // its contents may change on the next call to any 'seeks method'.
+ Value() []byte
+
+ // Release releases associated resources. Release should always success
+ // and can be called multiple times without causing error.
+ Release()
+
+ // Error returns any accumulated error. Exhausting all the key/value pairs
+ // is not considered to be an error.
+ Error() error
+}
+
+// Iteratee wraps the NewIterator methods of a backing data store.
+type Iteratee interface {
+ // NewIterator creates a binary-alphabetical iterator over a subset
+ // of database content with a particular key prefix, starting at a particular
+ // initial key (or after, if it does not exist).
+ //
+ // Note: This method assumes that the prefix is NOT part of the start, so there's
+ // no need for the caller to prepend the prefix to the start
+ NewIterator(prefix, start []byte) Iterator
+}
diff --git a/helper/kvdb/kvdb.go b/helper/kvdb/kvdb.go
deleted file mode 100644
index 30b1aa1d1e..0000000000
--- a/helper/kvdb/kvdb.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package kvdb
-
-type KVBatch interface {
- Set(k, v []byte)
- Write() error
-}
-
-type KVIteratorRange struct {
- Start []byte
- Limit []byte
-}
-
-type KVIterator interface {
- // First moves the iterator to the first key/value pair. If the iterator
- // only contains one key/value pair then First and Last would moves
- // to the same key/value pair.
- // It returns whether such pair exist.
- First() bool
-
- // Last moves the iterator to the last key/value pair. If the iterator
- // only contains one key/value pair then First and Last would moves
- // to the same key/value pair.
- // It returns whether such pair exist.
- Last() bool
-
- // Seek moves the iterator to the first key/value pair whose key is greater
- // than or equal to the given key.
- // It returns whether such pair exist.
- //
- // It is safe to modify the contents of the argument after Seek returns.
- Seek(key []byte) bool
-
- // Next moves the iterator to the next key/value pair.
- // It returns false if the iterator is exhausted.
- Next() bool
-
- // Prev moves the iterator to the previous key/value pair.
- // It returns false if the iterator is exhausted.
- Prev() bool
-
- // Key returns the key of the current key/value pair, or nil if done.
- // The caller should not modify the contents of the returned slice, and
- // its contents may change on the next call to any 'seeks method'.
- Key() []byte
-
- // Value returns the value of the current key/value pair, or nil if done.
- // The caller should not modify the contents of the returned slice, and
- // its contents may change on the next call to any 'seeks method'.
- Value() []byte
-
- // Release releases associated resources. Release should always success
- // and can be called multiple times without causing error.
- Release()
-
- // Error returns any accumulated error. Exhausting all the key/value pairs
- // is not considered to be an error.
- Error() error
-}
-
-// KVStorage is a k/v storage on memory or leveldb
-type KVStorage interface {
- Set(k, v []byte) error
- Get(k []byte) ([]byte, bool, error)
-
- Close() error
-}
-
-// KVBatchStorage is a batch write for leveldb
-type KVBatchStorage interface {
- KVStorage
-
- Iterator(*KVIteratorRange) KVIterator
-
- Batch() KVBatch
-}
diff --git a/helper/kvdb/leveldb.go b/helper/kvdb/leveldb.go
deleted file mode 100644
index cbbb417b45..0000000000
--- a/helper/kvdb/leveldb.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package kvdb
-
-import (
- "errors"
-
- "github.com/syndtr/goleveldb/leveldb"
- "github.com/syndtr/goleveldb/leveldb/util"
-)
-
-type levelBatch struct {
- db *leveldb.DB
- batch *leveldb.Batch
-}
-
-func (b *levelBatch) Set(k, v []byte) {
- b.batch.Put(k, v)
-}
-
-func (b *levelBatch) Write() error {
- return b.db.Write(b.batch, nil)
-}
-
-// levelDBKV is the leveldb implementation of the kv storage
-type levelDBKV struct {
- db *leveldb.DB
-}
-
-func (kv *levelDBKV) Batch() KVBatch {
- return &levelBatch{db: kv.db, batch: &leveldb.Batch{}}
-}
-
-func (kv *levelDBKV) Iterator(Range *KVIteratorRange) KVIterator {
- if Range == nil {
- return kv.db.NewIterator(nil, nil)
- }
-
- return kv.db.NewIterator(&util.Range{
- Start: Range.Start,
- Limit: Range.Limit,
- }, nil)
-}
-
-// Set sets the key-value pair in leveldb storage
-func (kv *levelDBKV) Set(p []byte, v []byte) error {
- return kv.db.Put(p, v, nil)
-}
-
-// Get retrieves the key-value pair in leveldb storage
-func (kv *levelDBKV) Get(p []byte) ([]byte, bool, error) {
- data, err := kv.db.Get(p, nil)
- if err != nil {
- if errors.Is(err, leveldb.ErrNotFound) {
- return nil, false, nil
- } else if errors.Is(err, leveldb.ErrClosed) {
- return nil, false, nil
- } else {
- panic(err)
- }
- }
-
- return data, true, nil
-}
-
-// Close closes the leveldb storage instance
-func (kv *levelDBKV) Close() error {
- return kv.db.Close()
-}
diff --git a/helper/kvdb/leveldb/leveldb.go b/helper/kvdb/leveldb/leveldb.go
new file mode 100644
index 0000000000..2608fd3080
--- /dev/null
+++ b/helper/kvdb/leveldb/leveldb.go
@@ -0,0 +1,176 @@
+package leveldb
+
+import (
+ "errors"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/hashicorp/go-hclog"
+ "github.com/syndtr/goleveldb/leveldb"
+ "github.com/syndtr/goleveldb/leveldb/opt"
+ "github.com/syndtr/goleveldb/leveldb/util"
+)
+
+const (
+ // base block size
+ blockSize = 2 * opt.MiB // default 4kb, but one key-value pair need 0.5kb
+
+ // minCache is the minimum memory allocate to leveldb
+ // half write, half read
+ minCache = 16 // 16 MiB
+
+ // minHandles is the minimum number of files handles to leveldb open files
+ minHandles = 16
+
+ DefaultCache = 1024 // 1 GiB
+ DefaultHandles = 512 // files handles to leveldb open files
+ DefaultBloomKeyBits = 2048 // bloom filter bits (256 bytes)
+ DefaultCompactionTableSize = 4 // 4 MiB
+ DefaultCompactionTotalSize = 40 // 40 MiB
+ DefaultNoSyncFlag = false // false - sync write, true - async write
+)
+
+type batch struct {
+ db *leveldb.DB
+ batch *leveldb.Batch
+ size int // counting batch size
+}
+
+func (b *batch) Set(k, v []byte) error {
+ b.batch.Put(k, v)
+ b.size += len(k) + len(v)
+
+ return nil
+}
+
+func (b *batch) Delete(k []byte) error {
+ b.batch.Delete(k)
+ b.size += len(k)
+
+ return nil
+}
+
+// ValueSize retrieves the amount of data queued up for writing.
+func (b *batch) ValueSize() int {
+ return b.size
+}
+
+func (b *batch) Write() error {
+ return b.db.Write(b.batch, nil)
+}
+
+// Reset resets the batch for reuse.
+func (b *batch) Reset() {
+ b.batch.Reset()
+ b.size = 0
+}
+
+// Replay replays the batch contents.
+func (b *batch) Replay(w kvdb.KVWriter) error {
+ return b.batch.Replay(&replayer{writer: w})
+}
+
+// replayer is a small wrapper to implement the correct replay methods.
+type replayer struct {
+ writer kvdb.KVWriter
+ failure error
+}
+
+// Put inserts the given value into the key-value data store.
+func (r *replayer) Put(key, value []byte) {
+ // If the replay already failed, stop executing ops
+ if r.failure != nil {
+ return
+ }
+
+ r.failure = r.writer.Set(key, value)
+}
+
+// Delete removes the key from the key-value data store.
+func (r *replayer) Delete(key []byte) {
+ // If the replay already failed, stop executing ops
+ if r.failure != nil {
+ return
+ }
+
+ r.failure = r.writer.Delete(key)
+}
+
+// database is the leveldb implementation of the kv storage
+type database struct {
+ db *leveldb.DB
+
+ logger kvdb.Logger
+}
+
+func (kv *database) NewBatch() kvdb.Batch {
+ return &batch{db: kv.db, batch: &leveldb.Batch{}}
+}
+
+// bytesPrefixRange returns key range that satisfy
+// - the given prefix, and
+// - the given seek position
+func bytesPrefixRange(prefix, start []byte) *util.Range {
+ r := util.BytesPrefix(prefix)
+ r.Start = append(r.Start, start...)
+
+ return r
+}
+
+func (kv *database) NewIterator(prefix, start []byte) kvdb.Iterator {
+ return kv.db.NewIterator(bytesPrefixRange(prefix, start), nil)
+}
+
+// Set sets the key-value pair in leveldb storage
+func (kv *database) Set(p []byte, v []byte) error {
+ return kv.db.Put(p, v, nil)
+}
+
+func (kv *database) Delete(p []byte) error {
+ return kv.db.Delete(p, nil)
+}
+
+func (kv *database) Has(p []byte) (bool, error) {
+ return kv.db.Has(p, nil)
+}
+
+// Get retrieves the key-value pair in leveldb storage
+func (kv *database) Get(p []byte) ([]byte, bool, error) {
+ data, err := kv.db.Get(p, nil)
+ if err != nil {
+ if errors.Is(err, leveldb.ErrNotFound) {
+ return nil, false, nil
+ } else if errors.Is(err, leveldb.ErrClosed) {
+ return nil, false, err
+ } else {
+ panic(err)
+ }
+ }
+
+ return data, true, nil
+}
+
+// Close closes the leveldb storage instance
+func (kv *database) Close() error {
+ return kv.db.Close()
+}
+
+func New(file string, options ...Option) (kvdb.Database, error) {
+ o := &dbOption{
+ logger: hclog.NewNullLogger(),
+ options: defaultLevelDBOptions(),
+ }
+
+ if err := handleOptions(o, options); err != nil {
+ return nil, err
+ }
+
+ db, err := leveldb.OpenFile(file, o.options)
+ if err != nil {
+ return nil, err
+ }
+
+ return &database{
+ db: db,
+ logger: o.logger,
+ }, nil
+}
diff --git a/helper/kvdb/leveldb/leveldb_test.go b/helper/kvdb/leveldb/leveldb_test.go
new file mode 100644
index 0000000000..ded0ffd01a
--- /dev/null
+++ b/helper/kvdb/leveldb/leveldb_test.go
@@ -0,0 +1,25 @@
+package leveldb
+
+import (
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/dbtest"
+ "github.com/syndtr/goleveldb/leveldb"
+ "github.com/syndtr/goleveldb/leveldb/storage"
+)
+
+func TestLevelDB(t *testing.T) {
+ t.Run("DatabaseSuite", func(t *testing.T) {
+ dbtest.TestDatabaseSuite(t, func() kvdb.KVBatchStorage {
+ db, err := leveldb.Open(storage.NewMemStorage(), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return &database{
+ db: db,
+ }
+ })
+ })
+}
diff --git a/helper/kvdb/leveldb/option.go b/helper/kvdb/leveldb/option.go
new file mode 100644
index 0000000000..433220f31f
--- /dev/null
+++ b/helper/kvdb/leveldb/option.go
@@ -0,0 +1,189 @@
+package leveldb
+
+import (
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/hashicorp/go-hclog"
+ "github.com/syndtr/goleveldb/leveldb/filter"
+ "github.com/syndtr/goleveldb/leveldb/opt"
+)
+
+// optionType tags how an option value is consumed; currently every option is
+// a plain function argument.
+type optionType string
+
+const (
+	optionArg optionType = "FuncArgument" // Function argument
+)
+
+// Keys under which option values are collected before being applied to the
+// final leveldb configuration in handleOptions.
+const (
+	bloomKeyBits        = "bloomKeyBits"
+	cacheSize           = "cacheSize"
+	compactionTableSize = "compactionTableSize"
+	compactionTotalSize = "compactionTotalSize"
+	handles             = "handles"
+	logger              = "logger"
+	noSync              = "noSync"
+	readOnly            = "readOnly"
+)
+
+type (
+	// optionValue is a tagged option value collected from an Option.
+	optionValue struct {
+		Value interface{}
+		Type  optionType
+	}
+
+	// Option leveldb option
+	Option func(map[string]optionValue) error
+)
+
+// addArg returns an Option that records the given key/value pair; nil values
+// are silently ignored.
+func addArg(key string, value interface{}) Option {
+	return func(params map[string]optionValue) error {
+		if value == nil {
+			return nil
+		}
+
+		params[key] = optionValue{value, optionArg}
+
+		return nil
+	}
+}
+
+// addArgError returns an Option that always fails with the given error; it
+// defers argument-validation failures until handleOptions runs.
+func addArgError(err error) Option {
+	return func(map[string]optionValue) error {
+		return err
+	}
+}
+
+// SetBloomKeyBits sets bloom filter bits per key
+func SetBloomKeyBits(v int) Option {
+	if v <= 0 {
+		return addArgError(fmt.Errorf("%s value must be greater than 0", bloomKeyBits))
+	}
+
+	return addArg(bloomKeyBits, v)
+}
+
+// SetCacheSize sets the cache size in MiB
+func SetCacheSize(v int) Option {
+	if v <= 0 {
+		return addArgError(fmt.Errorf("%s value must be greater than 0 MiB", cacheSize))
+	}
+
+	return addArg(cacheSize, v)
+}
+
+// SetCompactionTableSize sets compaction table size in MiB and
+// a write buffer twice the size
+//
+// It limits size of 'sorted table' that compaction generates.
+func SetCompactionTableSize(v int) Option {
+	if v <= 0 {
+		return addArgError(fmt.Errorf("%s value must be greater than 0 MiB", compactionTableSize))
+	}
+
+	return addArg(compactionTableSize, v)
+}
+
+// SetCompactionTotalSize sets total size of compaction table size
+// in MiB
+//
+// It limits total size of 'sorted table' for each level.
+func SetCompactionTotalSize(v int) Option {
+	if v <= 0 {
+		return addArgError(fmt.Errorf("%s value must be greater than 0 MiB", compactionTotalSize))
+	}
+
+	return addArg(compactionTotalSize, v)
+}
+
+// SetHandles sets the handles (file descriptor count)
+func SetHandles(v int) Option {
+	if v <= 0 {
+		return addArgError(fmt.Errorf("%s value must be greater than 0", handles))
+	}
+
+	return addArg(handles, v)
+}
+
+// SetLogger sets the outside logger to it
+//
+// The default one prints out nothing
+func SetLogger(v kvdb.Logger) Option {
+	if v == nil {
+		// Never store a nil logger; fall back to a no-op implementation.
+		v = hclog.NewNullLogger()
+	}
+
+	return addArg(logger, v)
+}
+
+// SetNoSync allows completely disable fsync
+func SetNoSync(v bool) Option {
+	return addArg(noSync, v)
+}
+
+// SetReadonly opens the database in read-only mode when true.
+func SetReadonly(v bool) Option {
+	return addArg(readOnly, v)
+}
+
+// defaultLevelDBOptions returns the baseline leveldb tuning used when the
+// caller supplies no overriding options.
+func defaultLevelDBOptions() *opt.Options {
+	return &opt.Options{
+		OpenFilesCacheCapacity:        minHandles,
+		CompactionTableSize:           DefaultCompactionTableSize * opt.MiB,
+		CompactionTotalSize:           DefaultCompactionTotalSize * opt.MiB,
+		BlockCacheCapacity:            minCache * opt.MiB,
+		WriteBuffer:                   (DefaultCompactionTableSize * 2) * opt.MiB,
+		CompactionTableSizeMultiplier: 1.1, // scale size up 1.1 multiple in next level
+		Filter:                        filter.NewBloomFilter(DefaultBloomKeyBits),
+		NoSync:                        false,
+		BlockSize:                     blockSize,
+		FilterBaseLg:                  19, // 512kb
+		DisableSeeksCompaction:        true,
+	}
+}
+
+// dbOption carries the logger and leveldb options accumulated while
+// processing the functional Option values.
+type dbOption struct {
+	logger  kvdb.Logger
+	options *opt.Options
+}
+
+// handleOptions applies the collected functional options onto the dbOption,
+// returning the first error any option reports.
+func handleOptions(o *dbOption, options []Option) error {
+	params := map[string]optionValue{}
+
+	// First pass: gather every option's key/value into params.
+	for _, apply := range options {
+		if apply == nil {
+			continue
+		}
+
+		if err := apply(params); err != nil {
+			return err
+		}
+	}
+
+	// Second pass: translate the collected parameters onto the leveldb config.
+	for key, val := range params {
+		//nolint:forcetypeassert
+		switch key {
+		case bloomKeyBits:
+			o.options.Filter = filter.NewBloomFilter(val.Value.(int))
+		case cacheSize:
+			o.options.BlockCacheCapacity = val.Value.(int) * opt.MiB
+		case compactionTableSize:
+			o.options.CompactionTableSize = val.Value.(int) * opt.MiB
+			o.options.WriteBuffer = o.options.CompactionTableSize * 2
+		case compactionTotalSize:
+			o.options.CompactionTotalSize = val.Value.(int) * opt.MiB
+		case handles:
+			o.options.OpenFilesCacheCapacity = val.Value.(int)
+		case logger:
+			o.logger = val.Value.(kvdb.Logger)
+		case noSync:
+			o.options.NoSync = val.Value.(bool)
+		case readOnly:
+			o.options.ReadOnly = val.Value.(bool)
+		default:
+			// Unknown key: nothing was applied, so skip the log line below.
+			continue
+		}
+
+		o.logger.Info("set leveldb option", "key", key, "value", val.Value)
+	}
+
+	return nil
+}
diff --git a/helper/kvdb/leveldb_test.go b/helper/kvdb/leveldb_test.go
deleted file mode 100644
index 81c3525b0f..0000000000
--- a/helper/kvdb/leveldb_test.go
+++ /dev/null
@@ -1,161 +0,0 @@
-package kvdb
-
-import (
- "encoding/binary"
- "io/ioutil"
- "math/rand"
- "os"
- "testing"
- "time"
-
- "github.com/hashicorp/go-hclog"
- "github.com/stretchr/testify/assert"
-)
-
-func createTestDB(t *testing.T) KVBatchStorage {
- t.Helper()
-
- tempDir, err := ioutil.TempDir("/tmp", "leveldb-")
- assert.NoError(t, err)
-
- db, err := NewLevelDBBuilder(
- hclog.NewNullLogger(),
- tempDir,
- ).Build()
- if err != nil {
- t.Fatal(err)
- }
-
- t.Cleanup(func() {
- os.RemoveAll(tempDir)
- })
-
- return db
-}
-
-func TestLevelDB(t *testing.T) {
- t.Parallel()
-
- t.Run("test KVStorage Get/Set", func(t *testing.T) {
- t.Parallel()
-
- db := createTestDB(t)
- defer db.Close()
-
- if err := db.Set([]byte("hello"), []byte("world")); err != nil {
- t.Fatal(err)
- }
-
- v, exist, err := db.Get([]byte("hello"))
- if err != nil {
- t.Fatal(err)
- }
-
- assert.True(t, exist)
-
- if string(v) != "world" {
- t.Fatal("value not equal")
- }
- })
-
- t.Run("test KVStorage Batch Write", func(t *testing.T) {
- t.Parallel()
-
- seed := rand.NewSource(time.Now().Unix())
-
- db := createTestDB(t)
- defer db.Close()
-
- keys := [][]byte{{}}
- values := [][]byte{{}}
-
- {
- batch := db.Batch()
- r := rand.New(seed)
-
- for i := 0; i < 100; i++ {
- key := make([]byte, 32)
- r.Read(key)
-
- keys = append(keys, key)
-
- value := make([]byte, 128)
- r.Read(value)
-
- values = append(values, value)
-
- batch.Set(key, value)
- }
-
- if err := batch.Write(); err != nil {
- t.Fatal(err)
- }
- }
-
- {
- for i := 1; i < len(keys); i++ {
- val, exist, err := db.Get(keys[i])
- if err != nil {
- t.Fatal(err)
- }
-
- assert.True(t, exist)
- assert.Equal(t, values[i], val)
- }
- }
- })
-
- t.Run("test KVStorage Iteration", func(t *testing.T) {
- t.Parallel()
-
- seed := rand.NewSource(time.Now().Unix())
-
- db := createTestDB(t)
- defer db.Close()
-
- keys := [][]byte{{}}
- values := [][]byte{{}}
-
- {
- batch := db.Batch()
- r := rand.New(seed)
-
- for i := 0; i < 100; i++ {
- key := make([]byte, 32)
- r.Read(key)
-
- prefix := make([]byte, 4)
- binary.LittleEndian.PutUint32(prefix, uint32(i))
-
- key = append(prefix[:], key...)
-
- keys = append(keys, key)
-
- value := make([]byte, 128)
- r.Read(value)
-
- values = append(values, value)
-
- batch.Set(key, value)
- }
-
- if err := batch.Write(); err != nil {
- t.Fatal(err)
- }
- }
-
- {
- iter := db.Iterator(nil)
- defer iter.Release()
-
- iter.First()
-
- for i := 1; i < len(keys); i++ {
- assert.Equal(t, keys[i], iter.Key())
- assert.Equal(t, values[i], iter.Value())
-
- iter.Next()
- }
- }
- })
-}
diff --git a/helper/kvdb/logger.go b/helper/kvdb/logger.go
new file mode 100644
index 0000000000..77d46b76f8
--- /dev/null
+++ b/helper/kvdb/logger.go
@@ -0,0 +1,16 @@
+package kvdb
+
+// Logger describes the interface that must be implemented by all loggers.
+type Logger interface {
+ // Emit a message and key/value pairs at the DEBUG level
+ Debug(msg string, args ...interface{})
+
+ // Emit a message and key/value pairs at the INFO level
+ Info(msg string, args ...interface{})
+
+ // Emit a message and key/value pairs at the WARN level
+ Warn(msg string, args ...interface{})
+
+ // Emit a message and key/value pairs at the ERROR level
+ Error(msg string, args ...interface{})
+}
diff --git a/helper/kvdb/memorydb/memorydb.go b/helper/kvdb/memorydb/memorydb.go
new file mode 100644
index 0000000000..36706c6704
--- /dev/null
+++ b/helper/kvdb/memorydb/memorydb.go
@@ -0,0 +1,394 @@
+package memorydb
+
+import (
+ "errors"
+ "sort"
+ "strings"
+ "sync"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+var (
+ // errMemorydbClosed is returned if a memory database was already closed at the
+ // invocation of a data access operation.
+ errMemorydbClosed = errors.New("database closed")
+
+ // errMemorydbNotFound is returned if a key is requested that is not found in
+ // the provided memory database.
+ errMemorydbNotFound = errors.New("not found")
+
+ // errSnapshotReleased is returned if callers want to retrieve data from a
+ // released snapshot.
+ errSnapshotReleased = errors.New("snapshot released")
+)
+
+// Database is an ephemeral key-value store. Apart from basic data storage
+// functionality it also supports batch writes and iterating over the keyspace in
+// binary-alphabetical order.
+type Database struct {
+	db   map[string][]byte // nil after Close; guarded by lock
+	lock sync.RWMutex
+}
+
+// New returns a wrapped map with all the required database interface methods
+// implemented.
+func New() *Database {
+	return &Database{
+		db: make(map[string][]byte),
+	}
+}
+
+// NewWithCap returns a wrapped map pre-allocated to the provided capacity with
+// all the required database interface methods implemented.
+func NewWithCap(size int) *Database {
+	return &Database{
+		db: make(map[string][]byte, size),
+	}
+}
+
+// Close deallocates the internal map and ensures any consecutive data access op
+// fails with an error.
+func (db *Database) Close() error {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	// A nil map is the "closed" marker checked by every accessor.
+	db.db = nil
+
+	return nil
+}
+
+// Has retrieves if a key is present in the key-value store.
+func (db *Database) Has(key []byte) (bool, error) {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	if db.db == nil {
+		return false, errMemorydbClosed
+	}
+
+	_, ok := db.db[string(key)]
+
+	return ok, nil
+}
+
+// Get retrieves the given key if it's present in the key-value store.
+//
+// NOTE(review): a missing key yields (nil, false, errMemorydbNotFound), while
+// the leveldb wrapper returns (nil, false, nil) for the same case — confirm
+// which contract kvdb.KVReader callers expect.
+func (db *Database) Get(key []byte) ([]byte, bool, error) {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	if db.db == nil {
+		return nil, false, errMemorydbClosed
+	}
+
+	// Return a copy so callers cannot mutate the stored value.
+	if entry, ok := db.db[string(key)]; ok {
+		return types.CopyBytes(entry), true, nil
+	}
+
+	return nil, false, errMemorydbNotFound
+}
+
+// Set inserts the given value into the key-value store.
+func (db *Database) Set(key []byte, value []byte) error {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	if db.db == nil {
+		return errMemorydbClosed
+	}
+
+	// Store a copy so later mutations by the caller are not visible here.
+	db.db[string(key)] = types.CopyBytes(value)
+
+	return nil
+}
+
+// Delete removes the key from the key-value store.
+func (db *Database) Delete(key []byte) error {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	if db.db == nil {
+		return errMemorydbClosed
+	}
+
+	delete(db.db, string(key))
+
+	return nil
+}
+
+// NewBatch creates a write-only key-value store that buffers changes to its host
+// database until a final write is called.
+func (db *Database) NewBatch() kvdb.Batch {
+	return &batch{
+		db: db,
+	}
+}
+
+// NewIterator creates a binary-alphabetical iterator over a subset
+// of database content with a particular key prefix, starting at a particular
+// initial key (or after, if it does not exist).
+func (db *Database) NewIterator(prefix []byte, start []byte) kvdb.Iterator {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	var (
+		pr = string(prefix)
+		// Build the start bound by string concatenation rather than
+		// append(prefix, start...): append could write into prefix's backing
+		// array when it has spare capacity, mutating the caller's slice.
+		st     = string(prefix) + string(start)
+		keys   = make([]string, 0, len(db.db))
+		values = make([][]byte, 0, len(db.db))
+	)
+
+	// Collect the keys from the memory database corresponding to the given prefix
+	// and start
+	for key := range db.db {
+		if !strings.HasPrefix(key, pr) {
+			continue
+		}
+
+		if key >= st {
+			keys = append(keys, key)
+		}
+	}
+
+	// Sort the items and retrieve the associated values
+	sort.Strings(keys)
+
+	for _, key := range keys {
+		values = append(values, db.db[key])
+	}
+
+	return &iterator{
+		index:  -1,
+		keys:   keys,
+		values: values,
+	}
+}
+
+// NewSnapshot creates a database snapshot based on the current state.
+// The created snapshot will not be affected by all following mutations
+// happened on the database.
+func (db *Database) NewSnapshot() (kvdb.Snapshot, error) {
+	return newSnapshot(db), nil
+}
+
+// Stat returns a particular internal stat of the database.
+// The memory database exposes no stats, so every property is unknown.
+func (db *Database) Stat(property string) (string, error) {
+	return "", errors.New("unknown property")
+}
+
+// Compact is not supported on a memory database, but there's no need either as
+// a memory database doesn't waste space anyway.
+func (db *Database) Compact(start []byte, limit []byte) error {
+	return nil
+}
+
+// Len returns the number of entries currently present in the memory database.
+//
+// Note, this method is only used for testing (i.e. not public in general) and
+// does not have explicit checks for closed-ness to allow simpler testing code.
+func (db *Database) Len() int {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	return len(db.db)
+}
+
+// keyvalue is a key-value tuple tagged with a deletion field to allow creating
+// memory-database write batches.
+type keyvalue struct {
+	key    []byte
+	value  []byte
+	delete bool
+}
+
+// batch is a write-only memory batch that commits changes to its host
+// database when Write is called. A batch cannot be used concurrently.
+type batch struct {
+	db     *Database
+	writes []keyvalue
+	size   int
+}
+
+// Set inserts the given value into the batch for later committing.
+func (b *batch) Set(key, value []byte) error {
+	// Copy both key and value: the caller may reuse its buffers before Write.
+	b.writes = append(b.writes, keyvalue{types.CopyBytes(key), types.CopyBytes(value), false})
+	b.size += len(key) + len(value)
+
+	return nil
+}
+
+// Delete inserts the a key removal into the batch for later committing.
+func (b *batch) Delete(key []byte) error {
+	b.writes = append(b.writes, keyvalue{types.CopyBytes(key), nil, true})
+	b.size += len(key)
+
+	return nil
+}
+
+// ValueSize retrieves the amount of data queued up for writing.
+func (b *batch) ValueSize() int {
+	return b.size
+}
+
+// Write flushes any accumulated data to the memory database.
+//
+// NOTE(review): Write does not check for a closed database (db.db == nil);
+// assigning into a nil map would panic — confirm Write cannot race with Close.
+func (b *batch) Write() error {
+	b.db.lock.Lock()
+	defer b.db.lock.Unlock()
+
+	for _, keyvalue := range b.writes {
+		if keyvalue.delete {
+			delete(b.db.db, string(keyvalue.key))
+
+			continue
+		}
+
+		b.db.db[string(keyvalue.key)] = keyvalue.value
+	}
+
+	return nil
+}
+
+// Reset resets the batch for reuse.
+func (b *batch) Reset() {
+	// Keep the backing array to avoid reallocating on the next use.
+	b.writes = b.writes[:0]
+	b.size = 0
+}
+
+// Replay replays the batch contents.
+func (b *batch) Replay(w kvdb.KVWriter) error {
+	for _, keyvalue := range b.writes {
+		if keyvalue.delete {
+			if err := w.Delete(keyvalue.key); err != nil {
+				return err
+			}
+
+			continue
+		}
+
+		if err := w.Set(keyvalue.key, keyvalue.value); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// iterator can walk over the (potentially partial) keyspace of a memory key
+// value store. Internally it is a deep copy of the entire iterated state,
+// sorted by keys.
+type iterator struct {
+	index  int
+	keys   []string
+	values [][]byte
+}
+
+// Next moves the iterator to the next key/value pair. It returns whether the
+// iterator is exhausted.
+func (it *iterator) Next() bool {
+	// Already walked past the end: remain exhausted.
+	if it.index >= len(it.keys) {
+		return false
+	}
+
+	it.index++
+
+	return it.index < len(it.keys)
+}
+
+// Error returns any accumulated error. Exhausting all the key/value pairs
+// is not considered to be an error. A memory iterator cannot encounter errors.
+func (it *iterator) Error() error {
+	return nil
+}
+
+// Key returns the key of the current key/value pair, or nil if done. The caller
+// should not modify the contents of the returned slice, and its contents may
+// change on the next call to Next.
+func (it *iterator) Key() []byte {
+	if it.inBounds() {
+		return []byte(it.keys[it.index])
+	}
+
+	return nil
+}
+
+// Value returns the value of the current key/value pair, or nil if done. The
+// caller should not modify the contents of the returned slice, and its contents
+// may change on the next call to Next.
+func (it *iterator) Value() []byte {
+	if it.inBounds() {
+		return it.values[it.index]
+	}
+
+	return nil
+}
+
+// inBounds reports whether the cursor currently points at a valid entry.
+func (it *iterator) inBounds() bool {
+	return it.index >= 0 && it.index < len(it.keys)
+}
+
+// Release releases associated resources. Release should always succeed and can
+// be called multiple times without causing error.
+func (it *iterator) Release() {
+	it.index, it.keys, it.values = -1, nil, nil
+}
+
+// snapshot wraps a batch of key-value entries deep copied from the in-memory
+// database for implementing the Snapshot interface.
+type snapshot struct {
+	db   map[string][]byte // nil after Release
+	lock sync.RWMutex
+}
+
+// newSnapshot initializes the snapshot with the given database instance.
+func newSnapshot(db *Database) *snapshot {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	// Deep-copy every entry so later writes to the database cannot leak into
+	// the snapshot.
+	copied := make(map[string][]byte)
+	for key, val := range db.db {
+		copied[key] = types.CopyBytes(val)
+	}
+
+	return &snapshot{db: copied}
+}
+
+// Has retrieves if a key is present in the snapshot backing by a key-value
+// data store.
+func (snap *snapshot) Has(key []byte) (bool, error) {
+	snap.lock.RLock()
+	defer snap.lock.RUnlock()
+
+	if snap.db == nil {
+		return false, errSnapshotReleased
+	}
+
+	_, ok := snap.db[string(key)]
+
+	return ok, nil
+}
+
+// Get retrieves the given key if it's present in the snapshot backing by
+// key-value data store.
+func (snap *snapshot) Get(key []byte) ([]byte, error) {
+	snap.lock.RLock()
+	defer snap.lock.RUnlock()
+
+	if snap.db == nil {
+		return nil, errSnapshotReleased
+	}
+
+	// Copy on read so callers cannot mutate the snapshot's contents.
+	if entry, ok := snap.db[string(key)]; ok {
+		return types.CopyBytes(entry), nil
+	}
+
+	return nil, errMemorydbNotFound
+}
+
+// Release releases associated resources. Release should always succeed and can
+// be called multiple times without causing error.
+func (snap *snapshot) Release() {
+	snap.lock.Lock()
+	defer snap.lock.Unlock()
+
+	snap.db = nil
+}
diff --git a/helper/kvdb/memorydb/memorydb_test.go b/helper/kvdb/memorydb/memorydb_test.go
new file mode 100644
index 0000000000..4d4a94995d
--- /dev/null
+++ b/helper/kvdb/memorydb/memorydb_test.go
@@ -0,0 +1,32 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package memorydb
+
+import (
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/dbtest"
+)
+
+func TestMemoryDB(t *testing.T) {
+ t.Run("DatabaseSuite", func(t *testing.T) {
+ dbtest.TestDatabaseSuite(t, func() kvdb.KVBatchStorage {
+ return New()
+ })
+ })
+}
diff --git a/helper/kvdb/snapshot.go b/helper/kvdb/snapshot.go
new file mode 100644
index 0000000000..e46d43ed6c
--- /dev/null
+++ b/helper/kvdb/snapshot.go
@@ -0,0 +1,25 @@
+package kvdb
+
+type Snapshot interface {
+ // Has retrieves if a key is present in the snapshot backing by a key-value
+ // data store.
+ Has(key []byte) (bool, error)
+
+ // Get retrieves the given key if it's present in the snapshot backing by
+ // key-value data store.
+ Get(key []byte) ([]byte, error)
+
+ // Release releases associated resources. Release should always succeed and can
+ // be called multiple times without causing error.
+ Release()
+}
+
+// Snapshotter wraps the Snapshot method of a backing data store.
+type Snapshotter interface {
+ // NewSnapshot creates a database snapshot based on the current state.
+ // The created snapshot will not be affected by all following mutations
+ // happened on the database.
+ // Note don't forget to release the snapshot once it's used up, otherwise
+ // the stale data will never be cleaned up by the underlying compactor.
+ NewSnapshot() (Snapshot, error)
+}
diff --git a/helper/metrics/helper.go b/helper/metrics/helper.go
new file mode 100644
index 0000000000..55866b350b
--- /dev/null
+++ b/helper/metrics/helper.go
@@ -0,0 +1,59 @@
+package metrics
+
+import (
+ "strings"
+
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+// helper function
+
+// ParseLables turns a flat list of ("key1", "value1", "key2", "value2", ...)
+// strings into a prometheus label set; it panics when the element count is
+// odd. NOTE(review): the exported name is misspelled ("Lables") but is kept
+// for backward compatibility with existing callers.
+func ParseLables(labelsWithValues ...string) prometheus.Labels {
+	if len(labelsWithValues)%2 != 0 {
+		panic("invalid labels")
+	}
+
+	constLabels := map[string]string{}
+	for i := 0; i+1 < len(labelsWithValues); i += 2 {
+		constLabels[labelsWithValues[i]] = labelsWithValues[i+1]
+	}
+
+	return constLabels
+}
+
+// CounterInc increments the counter, tolerating a nil metric (no-op).
+func CounterInc(counter prometheus.Counter) {
+	if counter == nil {
+		return
+	}
+
+	counter.Inc()
+}
+
+// AddCounter adds v to the counter, tolerating a nil metric (no-op).
+func AddCounter(counter prometheus.Counter, v float64) {
+	if counter == nil {
+		return
+	}
+
+	counter.Add(v)
+}
+
+// SetGauge sets the gauge to v, tolerating a nil metric (no-op).
+func SetGauge(gauge prometheus.Gauge, v float64) {
+	if gauge == nil {
+		return
+	}
+
+	gauge.Set(v)
+}
+
+// HistogramObserve records an observation, tolerating a nil metric (no-op).
+func HistogramObserve(histogram prometheus.Histogram, v float64) {
+	if histogram == nil {
+		return
+	}
+
+	histogram.Observe(v)
+}
+
+// MetricName2Help derives a human-readable help string from a metric name by
+// replacing underscores with spaces.
+func MetricName2Help(name string) string {
+	return strings.ReplaceAll(name, "_", " ")
+}
diff --git a/helper/metrics/metrics.go b/helper/metrics/metrics.go
index 1d070635ef..5fa79b7ab6 100644
--- a/helper/metrics/metrics.go
+++ b/helper/metrics/metrics.go
@@ -1,43 +1,169 @@
package metrics
-import "github.com/prometheus/client_golang/prometheus"
+import (
+ "sync/atomic"
+ "time"
-// helper function
+ "github.com/prometheus/client_golang/prometheus"
+)
-func ParseLables(labelsWithValues ...string) prometheus.Labels {
- constLabels := map[string]string{}
+type DurationUnit int
- if len(labelsWithValues)%2 == 0 {
- for i := 1; i < len(labelsWithValues); i += 2 {
- constLabels[labelsWithValues[i-1]] = labelsWithValues[i]
- }
+const (
+ DurationSecond DurationUnit = iota
+ DurationMillisecond
+ DurationMicrosecond
+ DurationNanosecond
+)
+
+// DurationContext is use for accumulating duration
+type DurationContext interface {
+ SetUnit(DurationUnit) // set time counting units
+ Start() // start timer, exclusive with Add
+ Add(time.Duration) // increase duration, exclusive with Start, should be thread safe
+ Duration() float64 // sum up duration on base unit
+}
+
+// CounterContext is use for accumulating value (count or size)
+type CounterContext interface {
+ Add(int64) // add value count(or size), should be thread safe
+ Count() int64 // number of times
+ Value() int64 // sum up value
+}
+
+// CumulativeDurationFn represents duration context handling for histogram
+// metric.
+type CumulativeDurationFn func(DurationContext)
+
+// DurationHistogram is histogram used for duration.
+type DurationHistogram interface {
+ TimeAccumulator() CumulativeDurationFn
+}
+
+// CountTotalFn is used to count total the context holding.
+type CountTotalFn func(CounterContext)
+
+// TotalCountHistogram is histogram used for total count (or size).
+type TotalCountHistogram interface {
+ CountAccumulator() CountTotalFn
+}
+
+// standard duration context
+type standardDurationContext struct {
+ unit DurationUnit
+ starttime time.Time
+ tmp int64
+}
+
+func NewDurationContextWithUnit(u DurationUnit) DurationContext {
+ return &standardDurationContext{unit: u}
+}
+func (ctx *standardDurationContext) SetUnit(u DurationUnit) { ctx.unit = u }
+func (ctx *standardDurationContext) Start() { ctx.starttime = time.Now() }
+func (ctx *standardDurationContext) Add(d time.Duration) { atomic.AddInt64(&ctx.tmp, int64(d)) }
+func (ctx *standardDurationContext) Duration() float64 {
+ var d time.Duration
+
+ if ctx.tmp > 0 {
+ d = time.Duration(ctx.tmp)
+ } else if ctx.starttime.IsZero() {
+ // donot start, and no duration cumulatived
+ return 0
} else {
- panic("invalid labels")
+ d = time.Since(ctx.starttime)
+ }
+
+ switch ctx.unit {
+ case DurationSecond:
+ return d.Seconds()
+ case DurationMillisecond:
+ return float64(d.Milliseconds())
+ case DurationMicrosecond:
+ return float64(d.Microseconds())
+ case DurationNanosecond:
+ return float64(d.Nanoseconds())
}
- return constLabels
+ return 0
}
-func CounterInc(counter prometheus.Counter) {
- if counter == nil {
- return
- }
+// nil duration context
+type nilDurationContext struct{}
+
+func NilDurationContext() DurationContext { return &nilDurationContext{} }
+func (ctx *nilDurationContext) SetUnit(u DurationUnit) {}
+func (ctx *nilDurationContext) Start() {}
+func (ctx *nilDurationContext) Add(d time.Duration) {}
+func (ctx *nilDurationContext) Duration() float64 { return 0 }
- counter.Inc()
+// standard counter context
+type standardCounterContext struct {
+ val int64
+ count int64
}
-func SetGauge(gauge prometheus.Gauge, v float64) {
- if gauge == nil {
- return
- }
+func NewCounterContext() CounterContext { return &standardCounterContext{} }
+func (ctx *standardCounterContext) Add(v int64) {
+ atomic.AddInt64(&ctx.val, v)
+ atomic.AddInt64(&ctx.count, 1)
+}
+// Count reports how many times Add has been called. It uses an atomic load
+// to pair with the atomic adds in Add, avoiding a data race when read
+// concurrently.
+func (ctx *standardCounterContext) Count() int64 { return atomic.LoadInt64(&ctx.count) }
+
+// Value reports the accumulated value, read atomically for the same reason.
+func (ctx *standardCounterContext) Value() int64 { return atomic.LoadInt64(&ctx.val) }
+
+// nil counter context
+type nilCounterContext struct{}
+
+func NilCounterContext() CounterContext { return &nilCounterContext{} }
+func (ctx *nilCounterContext) Add(v int64) {}
+func (ctx *nilCounterContext) Count() int64 { return 0 }
+func (ctx *nilCounterContext) Value() int64 { return 0 }
+
+// standard duration histogram metric
+type standardDurationHistogramMetric struct {
+ metric prometheus.Histogram
+}
- gauge.Set(v)
+func NewHistogramDurationMetric(metric prometheus.Histogram) DurationHistogram {
+ return &standardDurationHistogramMetric{metric: metric}
}
+func (m *standardDurationHistogramMetric) TimeAccumulator() CumulativeDurationFn {
+ return func(ctx DurationContext) {
+ if m.metric == nil {
+ return
+ }
-func HistogramObserve(histogram prometheus.Histogram, v float64) {
- if histogram == nil {
- return
+ m.metric.Observe(ctx.Duration())
}
+}
+
+// nil duration histogram metric
+type nilDurationHistogramMetric struct{}
+
+func NilHistogramDurationMetric() DurationHistogram {
+ return &nilDurationHistogramMetric{}
+}
+func (m *nilDurationHistogramMetric) TimeAccumulator() CumulativeDurationFn {
+ return func(dc DurationContext) {}
+}
+
+type standardTotalCountHistogram struct {
+ metric prometheus.Histogram
+}
+
+func NewTotalCounterHistogram(metric prometheus.Histogram) TotalCountHistogram {
+ return &standardTotalCountHistogram{metric: metric}
+}
+func (m *standardTotalCountHistogram) CountAccumulator() CountTotalFn {
+ return func(ctx CounterContext) {
+ if m.metric == nil {
+ return
+ }
- histogram.Observe(v)
+ m.metric.Observe(float64(ctx.Value()))
+ }
}
+
+type nilTotalCountHistogram struct{}
+
+func NilTotalCounterHistogram() TotalCountHistogram { return &nilTotalCountHistogram{} }
+func (m *nilTotalCountHistogram) CountAccumulator() CountTotalFn { return func(CounterContext) {} }
diff --git a/helper/rawdb/accessors_chain.go b/helper/rawdb/accessors_chain.go
new file mode 100644
index 0000000000..abe6c3c105
--- /dev/null
+++ b/helper/rawdb/accessors_chain.go
@@ -0,0 +1,134 @@
+package rawdb
+
+import (
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/dogechain-lab/fastrlp"
+)
+
+// ReadBody retrieves the RLP-decoded block body stored under the body key
+// for the given block hash.
+func ReadBody(db kvdb.KVReader, hash types.Hash) (*types.Body, error) {
+	body := new(types.Body)
+	err := readRLP(db, bodyKey(hash), body)
+
+	return body, err
+}
+
+// WriteBody stores the RLP-encoded block body under the body key.
+func WriteBody(db kvdb.KVWriter, hash types.Hash, body *types.Body) error {
+	return writeRLP(db, bodyKey(hash), body)
+}
+
+// ReadCanonicalHash retrieves the canonical block hash for the given number.
+//
+// NOTE(review): a storage error is reported the same way as a missing entry
+// (false) — confirm callers do not need to distinguish the two.
+func ReadCanonicalHash(db kvdb.KVReader, number uint64) (types.Hash, bool) {
+	data, ok, err := db.Get(canonicalHashKey(number))
+	if err != nil || !ok {
+		return types.Hash{}, false
+	}
+
+	return types.BytesToHash(data), true
+}
+
+// WriteCanonicalHash stores the canonical block hash for the given number.
+func WriteCanonicalHash(db kvdb.KVWriter, number uint64, hash types.Hash) error {
+	return db.Set(canonicalHashKey(number), hash.Bytes())
+}
+
+// ReadTotalDifficulty retrieves the total difficulty stored for the block
+// with the given hash; errors are conflated with absence (false).
+func ReadTotalDifficulty(db kvdb.KVReader, hash types.Hash) (*big.Int, bool) {
+	data, ok, err := db.Get(difficultyKey(hash))
+	if err != nil || !ok {
+		return nil, false
+	}
+
+	return new(big.Int).SetBytes(data), true
+}
+
+// WriteTotalDifficulty stores the total difficulty as its big-endian bytes.
+func WriteTotalDifficulty(db kvdb.KVWriter, hash types.Hash, diff *big.Int) error {
+	return db.Set(difficultyKey(hash), diff.Bytes())
+}
+
+// ReadHeader retrieves the RLP-decoded block header for the given hash.
+func ReadHeader(db kvdb.KVReader, hash types.Hash) (*types.Header, error) {
+	header := new(types.Header)
+	err := readRLP(db, headerKey(hash), header)
+
+	return header, err
+}
+
+// WriteHeader stores the RLP-encoded header under the header key.
+func WriteHeader(db kvdb.KVWriter, hash types.Hash, header *types.Header) error {
+	return writeRLP(db, headerKey(hash), header)
+}
+
+// ReadReceipts retrieves all receipts stored for the block with the given hash.
+func ReadReceipts(db kvdb.KVReader, hash types.Hash) ([]*types.Receipt, error) {
+	receipts := &types.Receipts{}
+	err := readRLP(db, receiptsKey(hash), receipts)
+
+	return *receipts, err
+}
+
+// WriteReceipts stores the RLP-encoded receipts under the receipts key.
+func WriteReceipts(db kvdb.KVWriter, hash types.Hash, receipts []*types.Receipt) error {
+	v := types.Receipts(receipts)
+
+	return writeRLP(db, receiptsKey(hash), &v)
+}
+
+// ReadTxLookup retrieves the hash of the block containing the given
+// transaction hash; any read or decode failure is reported as absence.
+func ReadTxLookup(db kvdb.KVReader, hash types.Hash) (types.Hash, bool) {
+	v, err := readRLP2(db, txLookupKey(hash))
+	if err != nil {
+		return types.Hash{}, false
+	}
+
+	// The stored value must be exactly a 32-byte hash.
+	blockHash, err := v.GetBytes(nil, 32)
+	if err != nil {
+		return types.Hash{}, false
+	}
+
+	return types.BytesToHash(blockHash), true
+}
+
+// WriteTxLookup stores the containing block hash for a transaction hash as an
+// RLP byte string.
+func WriteTxLookup(db kvdb.KVWriter, hash types.Hash, blockHash types.Hash) error {
+	var ar fastrlp.Arena
+	v := ar.NewBytes(blockHash.Bytes())
+
+	return writeRLP2(db, txLookupKey(hash), v)
+}
+
+// ReadHeadHash retrieves the hash of the current chain head, if stored.
+func ReadHeadHash(db kvdb.KVReader) (types.Hash, bool) {
+	data, ok, err := db.Get(headHashKey)
+	if err != nil || !ok {
+		return types.Hash{}, false
+	}
+
+	return types.BytesToHash(data), true
+}
+
+// WriteHeadHash stores the hash of the current chain head.
+func WriteHeadHash(db kvdb.KVWriter, hash types.Hash) error {
+	return db.Set(headHashKey, hash.Bytes())
+}
+
+// ReadHeadNumber retrieves the number of the current chain head; the stored
+// value must be exactly 8 bytes to be considered valid.
+func ReadHeadNumber(db kvdb.KVReader) (uint64, bool) {
+	data, ok, err := db.Get(headNumberKey)
+	if err != nil || !ok {
+		return 0, false
+	}
+
+	if len(data) != 8 {
+		return 0, false
+	}
+
+	return decodeUint(data), true
+}
+
+// WriteHeadNumber stores the number of the current chain head as fixed-width bytes.
+func WriteHeadNumber(db kvdb.KVWriter, number uint64) error {
+	return db.Set(headNumberKey, encodeUint(number))
+}
+
+// ReadForks retrieves the stored list of fork hashes.
+func ReadForks(db kvdb.KVReader) ([]types.Hash, error) {
+	forks := &Forks{}
+	err := readRLP(db, forkEmptyKey, forks)
+
+	return *forks, err
+}
+
+// WriteForks stores the list of fork hashes.
+func WriteForks(db kvdb.KVWriter, forks []types.Hash) error {
+	ff := Forks(forks)
+
+	return writeRLP(db, forkEmptyKey, &ff)
+}
diff --git a/helper/rawdb/accessors_snapshot.go b/helper/rawdb/accessors_snapshot.go
new file mode 100644
index 0000000000..62593ad697
--- /dev/null
+++ b/helper/rawdb/accessors_snapshot.go
@@ -0,0 +1,229 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "encoding/binary"
+ "log"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// ReadSnapshotDisabled retrieves if the snapshot maintenance is disabled.
+// A read error is treated the same as "flag not set".
+func ReadSnapshotDisabled(db kvdb.KVReader) bool {
+	disabled, _ := db.Has(snapshotDisabledKey)
+
+	return disabled
+}
+
+// WriteSnapshotDisabled stores the snapshot pause flag. Only the key's
+// presence matters; the stored value is arbitrary.
+func WriteSnapshotDisabled(db kvdb.KVWriter) {
+	if err := db.Set(snapshotDisabledKey, []byte("42")); err != nil {
+		logCrit("Failed to store snapshot disabled flag", "err", err)
+	}
+}
+
+// DeleteSnapshotDisabled deletes the flag keeping the snapshot maintenance disabled.
+func DeleteSnapshotDisabled(db kvdb.KVWriter) {
+	if err := db.Delete(snapshotDisabledKey); err != nil {
+		logCrit("Failed to remove snapshot disabled flag", "err", err)
+	}
+}
+
+// ReadSnapshotRoot retrieves the root of the block whose state is contained in
+// the persisted snapshot. A missing or wrongly-sized entry yields the zero hash.
+func ReadSnapshotRoot(db kvdb.KVReader) types.Hash {
+	data, _, _ := db.Get(snapshotRootKey)
+	if len(data) != types.HashLength {
+		return types.Hash{}
+	}
+
+	return types.BytesToHash(data)
+}
+
+// WriteSnapshotRoot stores the root of the block whose state is contained in
+// the persisted snapshot.
+func WriteSnapshotRoot(db kvdb.KVWriter, root types.Hash) {
+	if err := db.Set(snapshotRootKey, root[:]); err != nil {
+		logCrit("Failed to store snapshot root", "err", err)
+	}
+}
+
+// DeleteSnapshotRoot deletes the hash of the block whose state is contained in
+// the persisted snapshot. Since snapshots are not immutable, this method can
+// be used during updates, so a crash or failure will mark the entire snapshot
+// invalid.
+func DeleteSnapshotRoot(db kvdb.KVWriter) {
+	if err := db.Delete(snapshotRootKey); err != nil {
+		logCrit("Failed to remove snapshot root", "err", err)
+	}
+}
+
+// ReadAccountSnapshot retrieves the snapshot entry of an account trie leaf.
+// Missing entries and read errors both yield nil.
+func ReadAccountSnapshot(db kvdb.KVReader, hash types.Hash) []byte {
+	data, _, _ := db.Get(snapshotAccountKey(hash))
+
+	return data
+}
+
+// WriteAccountSnapshot stores the snapshot entry of an account trie leaf.
+func WriteAccountSnapshot(db kvdb.KVWriter, hash types.Hash, entry []byte) {
+	if err := db.Set(snapshotAccountKey(hash), entry); err != nil {
+		logCrit("Failed to store account snapshot", "err", err)
+	}
+}
+
+// DeleteAccountSnapshot removes the snapshot entry of an account trie leaf.
+func DeleteAccountSnapshot(db kvdb.KVWriter, hash types.Hash) {
+	if err := db.Delete(snapshotAccountKey(hash)); err != nil {
+		logCrit("Failed to delete account snapshot", "err", err)
+	}
+}
+
+// ReadStorageSnapshot retrieves the snapshot entry of a storage trie leaf.
+// Missing entries and read errors both yield nil.
+func ReadStorageSnapshot(db kvdb.KVReader, accountHash, storageHash types.Hash) []byte {
+	data, _, _ := db.Get(snapshotStorageKey(accountHash, storageHash))
+
+	return data
+}
+
+// WriteStorageSnapshot stores the snapshot entry of a storage trie leaf.
+func WriteStorageSnapshot(db kvdb.KVWriter, accountHash, storageHash types.Hash, entry []byte) {
+	if err := db.Set(snapshotStorageKey(accountHash, storageHash), entry); err != nil {
+		logCrit("Failed to store storage snapshot", "err", err)
+	}
+}
+
+// DeleteStorageSnapshot removes the snapshot entry of a storage trie leaf.
+func DeleteStorageSnapshot(db kvdb.KVWriter, accountHash, storageHash types.Hash) {
+	if err := db.Delete(snapshotStorageKey(accountHash, storageHash)); err != nil {
+		logCrit("Failed to delete storage snapshot", "err", err)
+	}
+}
+
+// IterateStorageSnapshots returns an iterator for walking the entire storage
+// space of a specific account. The length filter drops any key that is not
+// prefix + account hash + storage hash.
+func IterateStorageSnapshots(db kvdb.Iteratee, accountHash types.Hash) kvdb.Iterator {
+	return NewKeyLengthIterator(
+		db.NewIterator(SnapshotsStorageKey(accountHash), nil),
+		SnapshotPrefixLength+2*types.HashLength,
+	)
+}
+
+// ReadSnapshotJournal retrieves the serialized in-memory diff layers saved at
+// the last shutdown. The blob is expected to be max a few 10s of megabytes.
+// Missing entries and read errors both yield nil.
+func ReadSnapshotJournal(db kvdb.KVReader) []byte {
+	data, _, _ := db.Get(snapshotJournalKey)
+
+	return data
+}
+
+// WriteSnapshotJournal stores the serialized in-memory diff layers to save at
+// shutdown. The blob is expected to be max a few 10s of megabytes.
+func WriteSnapshotJournal(db kvdb.KVWriter, journal []byte) {
+	if err := db.Set(snapshotJournalKey, journal); err != nil {
+		logCrit("Failed to store snapshot journal", "err", err)
+	}
+}
+
+// DeleteSnapshotJournal deletes the serialized in-memory diff layers saved at
+// the last shutdown.
+func DeleteSnapshotJournal(db kvdb.KVWriter) {
+	if err := db.Delete(snapshotJournalKey); err != nil {
+		logCrit("Failed to remove snapshot journal", "err", err)
+	}
+}
+
+// ReadSnapshotGenerator retrieves the serialized snapshot generator saved at
+// the last shutdown. Missing entries and read errors both yield nil.
+func ReadSnapshotGenerator(db kvdb.KVReader) []byte {
+	data, _, _ := db.Get(snapshotGeneratorKey)
+
+	return data
+}
+
+// WriteSnapshotGenerator stores the serialized snapshot generator to save at
+// shutdown.
+func WriteSnapshotGenerator(db kvdb.KVWriter, generator []byte) {
+	if err := db.Set(snapshotGeneratorKey, generator); err != nil {
+		logCrit("Failed to store snapshot generator", "err", err)
+	}
+}
+
+// DeleteSnapshotGenerator deletes the serialized snapshot generator saved at
+// the last shutdown.
+func DeleteSnapshotGenerator(db kvdb.KVWriter) {
+	if err := db.Delete(snapshotGeneratorKey); err != nil {
+		logCrit("Failed to remove snapshot generator", "err", err)
+	}
+}
+
+// ReadSnapshotRecoveryNumber retrieves the block number of the last persisted
+// snapshot layer, or nil when the entry is absent or not 8 bytes long.
+func ReadSnapshotRecoveryNumber(db kvdb.KVReader) *uint64 {
+	data, _, _ := db.Get(snapshotRecoveryKey)
+	if len(data) == 0 {
+		return nil
+	}
+
+	// Reject anything but the fixed-width big-endian encoding.
+	if len(data) != 8 {
+		return nil
+	}
+
+	number := binary.BigEndian.Uint64(data)
+
+	return &number
+}
+
+// WriteSnapshotRecoveryNumber stores the block number of the last persisted
+// snapshot layer as a big-endian uint64.
+func WriteSnapshotRecoveryNumber(db kvdb.KVWriter, number uint64) {
+	var buf [8]byte
+
+	binary.BigEndian.PutUint64(buf[:], number)
+
+	if err := db.Set(snapshotRecoveryKey, buf[:]); err != nil {
+		logCrit("Failed to store snapshot recovery number", "err", err)
+	}
+}
+
+// DeleteSnapshotRecoveryNumber deletes the block number of the last persisted
+// snapshot layer.
+func DeleteSnapshotRecoveryNumber(db kvdb.KVWriter) {
+	if err := db.Delete(snapshotRecoveryKey); err != nil {
+		logCrit("Failed to remove snapshot recovery number", "err", err)
+	}
+}
+
+// ReadSnapshotSyncStatus retrieves the serialized sync status saved at shutdown.
+// Missing entries and read errors both yield nil.
+func ReadSnapshotSyncStatus(db kvdb.KVReader) []byte {
+	data, _, _ := db.Get(snapshotSyncStatusKey)
+
+	return data
+}
+
+// WriteSnapshotSyncStatus stores the serialized sync status to save at shutdown.
+func WriteSnapshotSyncStatus(db kvdb.KVWriter, status []byte) {
+	if err := db.Set(snapshotSyncStatusKey, status); err != nil {
+		logCrit("Failed to store snapshot sync status", "err", err)
+	}
+}
+
+// logCrit logs a critical storage failure and terminates the process.
+//
+// Fix: the previous implementation passed the variadic args to log.Fatal as
+// a single slice value, so the key/value pairs printed as "[err ...]" instead
+// of as individual operands; expand them with args... instead.
+func logCrit(msg string, args ...interface{}) {
+	log.Fatalln(append([]interface{}{msg}, args...)...)
+}
diff --git a/helper/rawdb/accessors_state.go b/helper/rawdb/accessors_state.go
new file mode 100644
index 0000000000..85fdc1c411
--- /dev/null
+++ b/helper/rawdb/accessors_state.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// ReadCode retrieves the contract code of the provided code hash.
+func ReadCode(db kvdb.KVReader, hash types.Hash) []byte {
+	// Try with the prefixed code scheme first, if not then try with legacy
+	// scheme.
+	data := ReadCodeWithPrefix(db, hash)
+	if len(data) != 0 {
+		return data
+	}
+
+	// Legacy scheme stored code directly under its hash; read errors
+	// degrade to an empty result.
+	data, _, _ = db.Get(hash.Bytes())
+
+	return data
+}
+
+// ReadCodeWithPrefix retrieves the contract code of the provided code hash.
+// The main difference between this function and ReadCode is this function
+// will only check the existence with latest scheme(with prefix).
+func ReadCodeWithPrefix(db kvdb.KVReader, hash types.Hash) []byte {
+	data, _, _ := db.Get(CodeKey(hash))
+
+	return data
+}
+
+// ReadTrieNode retrieves the trie node of the provided hash. Trie nodes are
+// stored unprefixed, directly under their hash.
+func ReadTrieNode(db kvdb.KVReader, hash types.Hash) []byte {
+	data, _, _ := db.Get(hash.Bytes())
+
+	return data
+}
+
+// HasCode checks if the contract code corresponding to the
+// provided code hash is present in the db.
+func HasCode(db kvdb.KVReader, hash types.Hash) bool {
+	// Try with the prefixed code scheme first, if not then try with legacy
+	// scheme.
+	if ok := HasCodeWithPrefix(db, hash); ok {
+		return true
+	}
+
+	// Legacy scheme: code stored directly under its hash.
+	ok, _ := db.Has(hash.Bytes())
+
+	return ok
+}
+
+// HasCodeWithPrefix checks if the contract code corresponding to the
+// provided code hash is present in the db. This function will only check
+// presence using the prefix-scheme.
+func HasCodeWithPrefix(db kvdb.KVReader, hash types.Hash) bool {
+	ok, _ := db.Has(CodeKey(hash))
+
+	return ok
+}
+
+// HasTrieNode checks if the trie node with the provided hash is present in db.
+func HasTrieNode(db kvdb.KVReader, hash types.Hash) bool {
+	ok, _ := db.Has(hash.Bytes())
+
+	return ok
+}
+
+// WriteCode writes the provided contract code to the database under the
+// prefixed code key.
+func WriteCode(db kvdb.KVWriter, hash types.Hash, code []byte) error {
+	return db.Set(CodeKey(hash), code)
+}
+
+// WriteTrieNode writes the provided trie node to the database, keyed
+// directly by its hash.
+func WriteTrieNode(db kvdb.KVWriter, hash types.Hash, node []byte) error {
+	return db.Set(hash.Bytes(), node)
+}
+
+// DeleteCode deletes the specified contract code from the database.
+func DeleteCode(db kvdb.KVWriter, hash types.Hash) error {
+	return db.Delete(CodeKey(hash))
+}
+
+// DeleteTrieNode deletes the specified trie node from the database.
+func DeleteTrieNode(db kvdb.KVWriter, hash types.Hash) error {
+	return db.Delete(hash.Bytes())
+}
diff --git a/helper/rawdb/database.go b/helper/rawdb/database.go
new file mode 100644
index 0000000000..90c609fb4e
--- /dev/null
+++ b/helper/rawdb/database.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import (
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+)
+
+// nofreezedb is a database wrapper that disables freezer data retrievals.
+// It embeds the key-value store and adds nothing else.
+type nofreezedb struct {
+	kvdb.KVBatchStorage
+}
+
+// NewDatabase creates a high level database on top of a given key-value data
+// store without a freezer moving immutable chain segments into cold storage.
+func NewDatabase(db kvdb.KVBatchStorage) kvdb.Database {
+	return &nofreezedb{KVBatchStorage: db}
+}
+
+// NewMemoryDatabase creates an ephemeral in-memory key-value database without a
+// freezer moving immutable chain segments into cold storage.
+func NewMemoryDatabase() kvdb.Database {
+	return NewDatabase(memorydb.New())
+}
+
+// NewMemoryDatabaseWithCap creates an ephemeral in-memory key-value database
+// with an initial starting capacity, but without a freezer moving immutable
+// chain segments into cold storage.
+func NewMemoryDatabaseWithCap(size int) kvdb.Database {
+	return NewDatabase(memorydb.NewWithCap(size))
+}
+
+// // NewLevelDBDatabase creates a persistent key-value database without a freezer
+// // moving immutable chain segments into cold storage.
+// func NewLevelDBDatabase(
+// file string,
+// cache int,
+// handles int,
+// logger hclog.Logger,
+// namespace string,
+// readonly bool,
+// ) (kvdb.Database, error) {
+// if logger == nil {
+// logger = hclog.NewNullLogger()
+// }
+
+// db, err := leveldb.New(file,
+// leveldb.SetBloomKeyBits(),
+// )
+// if err != nil {
+// return nil, err
+// }
+// return NewDatabase(db), nil
+// }
+
+// // NewLevelDBDatabaseWithFreezer creates a persistent key-value database with a
+// // freezer moving immutable chain segments into cold storage. The passed ancient
+// // indicates the path of root ancient directory where the chain freezer can be
+// // opened.
+// func NewLevelDBDatabaseWithFreezer(
+// file string,
+// cache int,
+// handles int,
+// ancient string,
+// namespace string,
+// readonly bool,
+// ) (kvdb.Database, error) {
+// kvdb, err := leveldb.New(file, cache, handles, namespace, readonly)
+// if err != nil {
+// return nil, err
+// }
+// frdb, err := NewDatabaseWithFreezer(kvdb, ancient, namespace, readonly)
+// if err != nil {
+// kvdb.Close()
+// return nil, err
+// }
+// return frdb, nil
+// }
diff --git a/blockchain/storage/utils.go b/helper/rawdb/forks.go
similarity index 98%
rename from blockchain/storage/utils.go
rename to helper/rawdb/forks.go
index 6e33bc6fc0..0676461431 100644
--- a/blockchain/storage/utils.go
+++ b/helper/rawdb/forks.go
@@ -1,4 +1,4 @@
-package storage
+package rawdb
import (
"github.com/dogechain-lab/dogechain/types"
diff --git a/helper/rawdb/key_length_iterator.go b/helper/rawdb/key_length_iterator.go
new file mode 100644
index 0000000000..d489555c2e
--- /dev/null
+++ b/helper/rawdb/key_length_iterator.go
@@ -0,0 +1,47 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rawdb
+
+import "github.com/dogechain-lab/dogechain/helper/kvdb"
+
+// KeyLengthIterator is a wrapper for a database iterator that ensures only key-value pairs
+// with a specific key length will be returned.
+type KeyLengthIterator struct {
+	requiredKeyLength int
+	kvdb.Iterator
+}
+
+// NewKeyLengthIterator returns a wrapped version of the iterator that will only return key-value
+// pairs where keys with a specific key length will be returned.
+func NewKeyLengthIterator(it kvdb.Iterator, keyLen int) kvdb.Iterator {
+	return &KeyLengthIterator{
+		Iterator:          it,
+		requiredKeyLength: keyLen,
+	}
+}
+
+// Next advances to the next entry whose key has exactly the required length,
+// silently skipping any other entries of the underlying iterator.
+func (it *KeyLengthIterator) Next() bool {
+	// Return true as soon as a key with the required key length is discovered
+	for it.Iterator.Next() {
+		if len(it.Iterator.Key()) == it.requiredKeyLength {
+			return true
+		}
+	}
+
+	// Return false when we exhaust the keys in the underlying iterator.
+	return false
+}
diff --git a/helper/rawdb/rlp.go b/helper/rawdb/rlp.go
new file mode 100644
index 0000000000..fa375446b7
--- /dev/null
+++ b/helper/rawdb/rlp.go
@@ -0,0 +1,56 @@
+package rawdb
+
+import (
+ "errors"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/dogechain-lab/fastrlp"
+)
+
+// ErrNotFound is returned by the read helpers when the key is absent.
+var ErrNotFound = errors.New("not found")
+
+// readRLP fetches the value stored under key and RLP-decodes it into raw,
+// preferring the store-format decoder when the target implements it.
+func readRLP(db kvdb.KVReader, key []byte, raw types.RLPUnmarshaler) error {
+	data, ok, err := db.Get(key)
+	if err != nil {
+		return err
+	} else if !ok {
+		return ErrNotFound
+	}
+
+	if obj, ok := raw.(types.RLPStoreUnmarshaler); ok {
+		// decode in the store format
+		return obj.UnmarshalStoreRLP(data)
+	}
+
+	// normal rlp decoding
+	return raw.UnmarshalRLP(data)
+}
+
+// writeRLP RLP-encodes raw (store format when available) and stores it
+// under key.
+func writeRLP(db kvdb.KVWriter, key []byte, raw types.RLPMarshaler) error {
+	var data []byte
+	if obj, ok := raw.(types.RLPStoreMarshaler); ok {
+		data = obj.MarshalStoreRLPTo(nil)
+	} else {
+		data = raw.MarshalRLPTo(nil)
+	}
+
+	return db.Set(key, data)
+}
+
+// readRLP2 fetches the value stored under key and parses it into a raw
+// fastrlp value. Returns ErrNotFound when the key is absent.
+// NOTE(review): this takes kvdb.Reader while the sibling helpers take
+// kvdb.KVReader — confirm the asymmetry is intentional.
+func readRLP2(db kvdb.Reader, key []byte) (*fastrlp.Value, error) {
+	data, ok, err := db.Get(key)
+	if err != nil {
+		return nil, err
+	} else if !ok {
+		return nil, ErrNotFound
+	}
+
+	return types.RlpUnmarshal(data)
+}
+
+// writeRLP2 marshals a raw fastrlp value and stores it under key.
+func writeRLP2(db kvdb.KVWriter, key []byte, v *fastrlp.Value) error {
+	dst := v.MarshalTo(nil)
+
+	return db.Set(key, dst)
+}
diff --git a/helper/rawdb/schema.go b/helper/rawdb/schema.go
new file mode 100644
index 0000000000..c8a44dc5aa
--- /dev/null
+++ b/helper/rawdb/schema.go
@@ -0,0 +1,120 @@
+package rawdb
+
+import (
+ "encoding/binary"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+const (
+ SnapshotPrefixLength = 2 // "s*" s means snapshot, * is the first letter of the subkey
+)
+
+// snapshot key prefix
+var (
+ // codePrefix is the code prefix for leveldb
+ codePrefix = []byte("code")
+ // SnapshotAccountPrefix + account hash -> account trie value
+ SnapshotAccountPrefix = []byte("sa")
+ // SnapshotStoragePrefix + account hash + storage hash -> storage trie value
+ SnapshotStoragePrefix = []byte("ss")
+)
+
+// snapshot key
+var (
+ // snapshotDisabledKey flags that the snapshot should not be maintained due to initial sync.
+ snapshotDisabledKey = []byte("SnapshotDisabled")
+ // snapshotRootKey tracks the hash of the last snapshot.
+ snapshotRootKey = []byte("SnapshotRoot")
+ // snapshotJournalKey tracks the in-memory diff layers across restarts.
+ snapshotJournalKey = []byte("SnapshotJournal")
+ // snapshotGeneratorKey tracks the snapshot generation marker across restarts.
+ snapshotGeneratorKey = []byte("SnapshotGenerator")
+ // snapshotRecoveryKey tracks the snapshot recovery marker across restarts.
+ snapshotRecoveryKey = []byte("SnapshotRecovery")
+ // snapshotSyncStatusKey tracks the snapshot sync status across restarts.
+ snapshotSyncStatusKey = []byte("SnapshotSyncStatus")
+ // skeletonSyncStatusKey tracks the skeleton sync status across restarts.
+ skeletonSyncStatusKey = []byte("SkeletonSyncStatus")
+)
+
+// blockchain key prefix
+var (
+ // bodyPrefix + header hash -> body
+ bodyPrefix = []byte("b")
+ // canonicalPrefix + block number (big endian uint64) -> canonical block(header) hash
+ canonicalPrefix = []byte("c")
+ // difficultyPrefix + header hash -> difficulty
+ difficultyPrefix = []byte("d")
+ // headerPrefix + header hash -> header
+ headerPrefix = []byte("h")
+	// receiptsPrefix + block hash -> block receipts (see Read/WriteReceipts)
+ receiptsPrefix = []byte("r")
+ // txLookupPrefix + transaction hash -> block hash
+ txLookupPrefix = []byte("l")
+)
+
+// blockchain keys
+var (
+ // headHashKey tracks the latest known header's hash
+ headHashKey = []byte("ohash")
+ // headNumberKey tracks the latest known header's number
+ headNumberKey = []byte("onumber")
+ // forkEmptyKey tracks any fork which never exists
+ forkEmptyKey = []byte("empty")
+)
+
+// encodeUint encodes n as a fixed-width 8-byte big-endian value.
+func encodeUint(n uint64) []byte {
+	b := make([]byte, 8)
+	binary.BigEndian.PutUint64(b[:], n)
+
+	return b[:]
+}
+
+// decodeUint decodes a fixed-width 8-byte big-endian value.
+// Callers must guarantee len(b) >= 8; shorter input panics.
+func decodeUint(b []byte) uint64 {
+	return binary.BigEndian.Uint64(b[:])
+}
+
+// NOTE(review): the key builders below append onto package-level prefix
+// slices; this is safe only while those prefixes have no spare capacity in
+// their backing arrays — confirm, or copy into a fresh buffer.
+
+// CodeKey = CodePrefix + hash
+func CodeKey(hash types.Hash) []byte {
+	return append(codePrefix, hash.Bytes()...)
+}
+
+// snapshotAccountKey = SnapshotAccountPrefix + hash
+func snapshotAccountKey(hash types.Hash) []byte {
+	return append(SnapshotAccountPrefix, hash.Bytes()...)
+}
+
+// snapshotStorageKey = SnapshotStoragePrefix + account hash + storage hash
+func snapshotStorageKey(accountHash, storageHash types.Hash) []byte {
+	return append(append(SnapshotStoragePrefix, accountHash.Bytes()...), storageHash.Bytes()...)
+}
+
+// SnapshotsStorageKey = SnapshotStoragePrefix + account hash (+ storage hash)
+func SnapshotsStorageKey(accountHash types.Hash) []byte {
+	return append(SnapshotStoragePrefix, accountHash.Bytes()...)
+}
+
+// bodyKey = bodyPrefix + block hash
+func bodyKey(h types.Hash) []byte {
+	return append(bodyPrefix, h.Bytes()...)
+}
+
+// canonicalHashKey = canonicalPrefix + block number (big endian)
+func canonicalHashKey(n uint64) []byte {
+	return append(canonicalPrefix, encodeUint(n)...)
+}
+
+// difficultyKey = difficultyPrefix + header hash
+func difficultyKey(h types.Hash) []byte {
+	return append(difficultyPrefix, h.Bytes()...)
+}
+
+// headerKey = headerPrefix + header hash
+func headerKey(h types.Hash) []byte {
+	return append(headerPrefix, h.Bytes()...)
+}
+
+// receiptsKey = receiptsPrefix + block hash
+func receiptsKey(h types.Hash) []byte {
+	return append(receiptsPrefix, h.Bytes()...)
+}
+
+// txLookupKey = txLookupPrefix + transaction hash
+func txLookupKey(h types.Hash) []byte {
+	return append(txLookupPrefix, h.Bytes()...)
+}
diff --git a/helper/rlp/decode.go b/helper/rlp/decode.go
new file mode 100644
index 0000000000..fa918bb02c
--- /dev/null
+++ b/helper/rlp/decode.go
@@ -0,0 +1,1257 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "reflect"
+ "strings"
+ "sync"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp/internal/rlpstruct"
+)
+
+//lint:ignore ST1012 EOL is not an error.
+
+// EOL is returned when the end of the current list
+// has been reached during streaming.
+//
+//nolint:stylecheck
+var EOL = errors.New("rlp: end of list")
+
+var (
+ ErrExpectedString = errors.New("rlp: expected String or Byte")
+ ErrExpectedList = errors.New("rlp: expected List")
+ ErrCanonInt = errors.New("rlp: non-canonical integer format")
+ ErrCanonSize = errors.New("rlp: non-canonical size information")
+ ErrElemTooLarge = errors.New("rlp: element is larger than containing list")
+ ErrValueTooLarge = errors.New("rlp: value size exceeds available input length")
+ ErrMoreThanOneValue = errors.New("rlp: input contains more than one value")
+
+ // internal errors
+ errNotInList = errors.New("rlp: call of ListEnd outside of any list")
+ errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL")
+ errUintOverflow = errors.New("rlp: uint overflow")
+ errNoPointer = errors.New("rlp: interface given to Decode must be a pointer")
+ errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil")
+
+ streamPool = sync.Pool{
+ New: func() interface{} { return new(Stream) },
+ }
+)
+
+// Decoder is implemented by types that require custom RLP decoding rules or need to decode
+// into private fields.
+//
+// The DecodeRLP method should read one value from the given Stream. It is not forbidden to
+// read less or more, but it might be confusing.
+type Decoder interface {
+	DecodeRLP(*Stream) error
+}
+
+// Decode parses RLP-encoded data from r and stores the result in the value pointed to by
+// val. Please see package-level documentation for the decoding rules. Val must be a
+// non-nil pointer.
+//
+// If r does not implement ByteReader, Decode will do its own buffering.
+//
+// Note that Decode does not set an input limit for all readers and may be vulnerable to
+// panics caused by huge value sizes. If you need an input limit, use
+//
+//	NewStream(r, limit).Decode(val)
+func Decode(r io.Reader, val interface{}) error {
+	// Streams are pooled to avoid a per-call allocation; Reset rebinds the
+	// pooled stream to this reader with no size limit.
+	stream, _ := streamPool.Get().(*Stream)
+	defer streamPool.Put(stream)
+
+	stream.Reset(r, 0)
+
+	return stream.Decode(val)
+}
+
+// DecodeBytes parses RLP data from b into val. Please see package-level documentation for
+// the decoding rules. The input must contain exactly one value and no trailing data.
+func DecodeBytes(b []byte, val interface{}) error {
+	r := bytes.NewReader(b)
+
+	stream, _ := streamPool.Get().(*Stream)
+	defer streamPool.Put(stream)
+
+	// Bound the stream by the input length so oversized values fail early.
+	stream.Reset(r, uint64(len(b)))
+
+	if err := stream.Decode(val); err != nil {
+		return err
+	}
+
+	// Reject trailing data: the input must be exactly one RLP value.
+	if r.Len() > 0 {
+		return ErrMoreThanOneValue
+	}
+
+	return nil
+}
+
+// decodeError describes a decoding failure, carrying the target type and a
+// stack of context strings (innermost first) describing where it happened.
+type decodeError struct {
+	msg string
+	typ reflect.Type
+	ctx []string
+}
+
+// Error renders the message with the context stack in outermost-first order.
+func (err *decodeError) Error() string {
+	ctx := ""
+	if len(err.ctx) > 0 {
+		ctx = ", decoding into "
+		// ctx entries were appended innermost-first; walk backwards.
+		for i := len(err.ctx) - 1; i >= 0; i-- {
+			ctx += err.ctx[i]
+		}
+	}
+
+	return fmt.Sprintf("rlp: %s for %v%s", err.msg, err.typ, ctx)
+}
+
+// wrapStreamError converts well-known stream errors into typed decodeErrors
+// carrying the target type; other errors pass through unchanged.
+func wrapStreamError(err error, typ reflect.Type) error {
+	switch {
+	case errors.Is(err, ErrCanonInt):
+		return &decodeError{msg: "non-canonical integer (leading zero bytes)", typ: typ}
+	case errors.Is(err, ErrCanonSize):
+		return &decodeError{msg: "non-canonical size information", typ: typ}
+	case errors.Is(err, ErrExpectedList):
+		return &decodeError{msg: "expected input list", typ: typ}
+	case errors.Is(err, ErrExpectedString):
+		return &decodeError{msg: "expected input string or byte", typ: typ}
+	case errors.Is(err, errUintOverflow):
+		return &decodeError{msg: "input string too long", typ: typ}
+	case errors.Is(err, errNotAtEOL):
+		return &decodeError{msg: "input list has too many elements", typ: typ}
+	}
+
+	return err
+}
+
+// addErrorContext appends a context string (e.g. a field name or index) to a
+// decodeError; non-decodeErrors are returned unchanged.
+func addErrorContext(err error, ctx string) error {
+	var decErr *decodeError
+	if errors.As(err, &decErr) {
+		decErr.ctx = append(decErr.ctx, ctx)
+	}
+
+	return err
+}
+
+var (
+ decoderInterface = reflect.TypeOf(new(Decoder)).Elem()
+ bigInt = reflect.TypeOf(big.Int{})
+)
+
+// makeDecoder builds the decoder function for typ. The case order matters:
+// raw values and big.Int are special-cased before generic pointer handling,
+// and the Decoder interface check precedes the kind-based cases.
+func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) {
+	kind := typ.Kind()
+
+	switch {
+	case typ == rawValueType:
+		return decodeRawValue, nil
+	case typ.AssignableTo(reflect.PtrTo(bigInt)):
+		return decodeBigInt, nil
+	case typ.AssignableTo(bigInt):
+		return decodeBigIntNoPtr, nil
+	case kind == reflect.Ptr:
+		return makePtrDecoder(typ, tags)
+	case reflect.PtrTo(typ).Implements(decoderInterface):
+		return decodeDecoder, nil
+	case isUint(kind):
+		return decodeUint, nil
+	case kind == reflect.Bool:
+		return decodeBool, nil
+	case kind == reflect.String:
+		return decodeString, nil
+	case kind == reflect.Slice || kind == reflect.Array:
+		return makeListDecoder(typ, tags)
+	case kind == reflect.Struct:
+		return makeStructDecoder(typ)
+	case kind == reflect.Interface:
+		return decodeInterface, nil
+	default:
+		return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ)
+	}
+}
+
+// decodeRawValue captures the raw RLP encoding of one value, unparsed.
+func decodeRawValue(s *Stream, val reflect.Value) error {
+	r, err := s.Raw()
+	if err != nil {
+		return err
+	}
+
+	val.SetBytes(r)
+
+	return nil
+}
+
+// decodeUint decodes an unsigned integer bounded by the target's bit width.
+func decodeUint(s *Stream, val reflect.Value) error {
+	typ := val.Type()
+
+	num, err := s.uint(typ.Bits())
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	val.SetUint(num)
+
+	return nil
+}
+
+// decodeBool decodes an RLP boolean into the target value.
+func decodeBool(s *Stream, val reflect.Value) error {
+	b, err := s.Bool()
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	val.SetBool(b)
+
+	return nil
+}
+
+// decodeString decodes an RLP byte string into a Go string.
+func decodeString(s *Stream, val reflect.Value) error {
+	b, err := s.Bytes()
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	val.SetString(string(b))
+
+	return nil
+}
+
+// decodeBigIntNoPtr adapts decodeBigInt for non-pointer big.Int values.
+func decodeBigIntNoPtr(s *Stream, val reflect.Value) error {
+	return decodeBigInt(s, val.Addr())
+}
+
+// decodeBigInt decodes into a *big.Int, allocating the target when nil.
+func decodeBigInt(s *Stream, val reflect.Value) error {
+	i, _ := val.Interface().(*big.Int)
+	if i == nil {
+		i = new(big.Int)
+		val.Set(reflect.ValueOf(i))
+	}
+
+	err := s.decodeBigInt(i)
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	return nil
+}
+
+// makeListDecoder builds the decoder for slice/array types. Byte slices and
+// byte arrays get dedicated fast paths unless the element type has a custom
+// Decoder implementation.
+func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
+	etype := typ.Elem()
+	if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) {
+		if typ.Kind() == reflect.Array {
+			return decodeByteArray, nil
+		}
+
+		return decodeByteSlice, nil
+	}
+
+	etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
+	if etypeinfo.decoderErr != nil {
+		return nil, etypeinfo.decoderErr
+	}
+
+	var dec decoder
+
+	switch {
+	case typ.Kind() == reflect.Array:
+		dec = func(s *Stream, val reflect.Value) error {
+			return decodeListArray(s, val, etypeinfo.decoder)
+		}
+	case tag.Tail:
+		// A slice with "tail" tag can occur as the last field
+		// of a struct and is supposed to swallow all remaining
+		// list elements. The struct decoder already called s.List,
+		// proceed directly to decoding the elements.
+		dec = func(s *Stream, val reflect.Value) error {
+			return decodeSliceElems(s, val, etypeinfo.decoder)
+		}
+	default:
+		dec = func(s *Stream, val reflect.Value) error {
+			return decodeListSlice(s, val, etypeinfo.decoder)
+		}
+	}
+
+	return dec, nil
+}
+
+// decodeListSlice decodes an RLP list into a slice, consuming the list
+// delimiters (List/ListEnd) around the elements.
+func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
+	size, err := s.List()
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	if size == 0 {
+		// Set a non-nil empty slice to distinguish "[]" from absent.
+		val.Set(reflect.MakeSlice(val.Type(), 0, 0))
+
+		return s.ListEnd()
+	}
+
+	if err := decodeSliceElems(s, val, elemdec); err != nil {
+		return err
+	}
+
+	return s.ListEnd()
+}
+
+// decodeSliceElems decodes list elements into val until EOL, growing the
+// slice geometrically and trimming any excess length afterwards.
+func decodeSliceElems(s *Stream, val reflect.Value, elemdec decoder) error {
+	i := 0
+
+	for ; ; i++ {
+		// grow slice if necessary
+		if i >= val.Cap() {
+			newcap := val.Cap() + val.Cap()/2
+			if newcap < 4 {
+				newcap = 4
+			}
+
+			newv := reflect.MakeSlice(val.Type(), val.Len(), newcap)
+			reflect.Copy(newv, val)
+			val.Set(newv)
+		}
+
+		if i >= val.Len() {
+			val.SetLen(i + 1)
+		}
+		// decode into element
+		if err := elemdec(s, val.Index(i)); errors.Is(err, EOL) {
+			break
+		} else if err != nil {
+			return addErrorContext(err, fmt.Sprint("[", i, "]"))
+		}
+	}
+
+	if i < val.Len() {
+		val.SetLen(i)
+	}
+
+	return nil
+}
+
+// decodeListArray decodes an RLP list into a fixed-size array; the input
+// must contain exactly as many elements as the array holds.
+func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
+	if _, err := s.List(); err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	vlen := val.Len()
+	i := 0
+
+	for ; i < vlen; i++ {
+		if err := elemdec(s, val.Index(i)); errors.Is(err, EOL) {
+			break
+		} else if err != nil {
+			return addErrorContext(err, fmt.Sprint("[", i, "]"))
+		}
+	}
+
+	if i < vlen {
+		return &decodeError{msg: "input list has too few elements", typ: val.Type()}
+	}
+
+	// ListEnd also catches the "too many elements" case via errNotAtEOL.
+	return wrapStreamError(s.ListEnd(), val.Type())
+}
+
+// decodeByteSlice decodes an RLP string into a []byte value.
+func decodeByteSlice(s *Stream, val reflect.Value) error {
+	b, err := s.Bytes()
+	if err != nil {
+		return wrapStreamError(err, val.Type())
+	}
+
+	val.SetBytes(b)
+
+	return nil
+}
+
+// decodeByteArray decodes an RLP string into a fixed-length byte array.
+// The input size must match the array length exactly, and canonical
+// encoding is enforced (a one-byte value < 128 must use Byte encoding).
+func decodeByteArray(s *Stream, val reflect.Value) error {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return err
+	}
+
+	slice := byteArrayBytes(val, val.Len())
+
+	switch kind {
+	case Byte:
+		// A single-byte value only fits an array of length exactly 1.
+		if len(slice) == 0 {
+			return &decodeError{msg: "input string too long", typ: val.Type()}
+		} else if len(slice) > 1 {
+			return &decodeError{msg: "input string too short", typ: val.Type()}
+		}
+
+		slice[0] = s.byteval
+		s.kind = -1 // rearm Kind: the byte has been consumed
+	case String:
+		if uint64(len(slice)) < size {
+			return &decodeError{msg: "input string too long", typ: val.Type()}
+		} else if uint64(len(slice)) > size {
+			return &decodeError{msg: "input string too short", typ: val.Type()}
+		}
+
+		if err := s.readFull(slice); err != nil {
+			return err
+		}
+		// Reject cases where single byte encoding should have been used.
+		if size == 1 && slice[0] < 128 {
+			return wrapStreamError(ErrCanonSize, val.Type())
+		}
+	case List:
+		return wrapStreamError(ErrExpectedString, val.Type())
+	}
+
+	return nil
+}
+
+// makeStructDecoder builds a decoder for a struct type. Per-field decoders
+// are resolved up front so type errors are reported at construction time;
+// fields marked optional may be absent at the tail of the input list.
+func makeStructDecoder(typ reflect.Type) (decoder, error) {
+	fields, err := structFields(typ)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, f := range fields {
+		if f.info.decoderErr != nil {
+			return nil, structFieldError{typ, f.index, f.info.decoderErr}
+		}
+	}
+
+	dec := func(s *Stream, val reflect.Value) (err error) {
+		if _, err := s.List(); err != nil {
+			return wrapStreamError(err, typ)
+		}
+
+		for i, f := range fields {
+			err := f.info.decoder(s, val.Field(f.index))
+			if errors.Is(err, EOL) {
+				if f.optional {
+					// The field is optional, so reaching the end of the list before
+					// reaching the last field is acceptable. All remaining undecoded
+					// fields are zeroed.
+					zeroFields(val, fields[i:])
+
+					break
+				}
+
+				return &decodeError{msg: "too few elements", typ: typ}
+			} else if err != nil {
+				return addErrorContext(err, "."+typ.Field(f.index).Name)
+			}
+		}
+
+		return wrapStreamError(s.ListEnd(), typ)
+	}
+
+	return dec, nil
+}
+
+// zeroFields resets the given struct fields to their zero values. It is
+// used when optional tail fields are missing from the input list.
+func zeroFields(structval reflect.Value, fields []field) {
+	for _, f := range fields {
+		fv := structval.Field(f.index)
+		fv.Set(reflect.Zero(fv.Type()))
+	}
+}
+
+// makePtrDecoder creates a decoder that decodes into the pointer's element type.
+func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) {
+	etype := typ.Elem()
+	etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{})
+
+	switch {
+	case etypeinfo.decoderErr != nil:
+		return nil, etypeinfo.decoderErr
+	case !tag.NilOK:
+		// Without the "nil" struct tag, empty values decode into a freshly
+		// allocated element rather than a nil pointer.
+		return makeSimplePtrDecoder(etype, etypeinfo), nil
+	default:
+		return makeNilPtrDecoder(etype, etypeinfo, tag), nil
+	}
+}
+
+// makeSimplePtrDecoder returns a decoder that allocates a new element when
+// the pointer is nil and assigns it to val only after a successful decode,
+// so a failed decode leaves val untouched.
+func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder {
+	return func(s *Stream, val reflect.Value) (err error) {
+		newval := val
+		if val.IsNil() {
+			newval = reflect.New(etype)
+		}
+
+		if err = etypeinfo.decoder(s, newval.Elem()); err == nil {
+			val.Set(newval)
+		}
+
+		return err
+	}
+}
+
+// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty
+// values are decoded into a value of the element type, just like makePtrDecoder does.
+//
+// This decoder is used for pointer-typed struct fields with struct tag "nil".
+func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder {
+	typ := reflect.PtrTo(etype)
+	nilPtr := reflect.Zero(typ)
+
+	// Determine the value kind that results in nil pointer.
+	nilKind := typeNilKind(etype, ts)
+
+	return func(s *Stream, val reflect.Value) (err error) {
+		kind, size, err := s.Kind()
+		if err != nil {
+			// Zero the pointer on error so no stale value remains.
+			val.Set(nilPtr)
+
+			return wrapStreamError(err, typ)
+		}
+		// Handle empty values as a nil pointer.
+		if kind != Byte && size == 0 {
+			// The empty value must be of the expected kind (String or List).
+			if kind != nilKind {
+				return &decodeError{
+					msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind),
+					typ: typ,
+				}
+			}
+
+			// rearm s.Kind. This is important because the input
+			// position must advance to the next value even though
+			// we don't read anything.
+			s.kind = -1
+
+			val.Set(nilPtr)
+
+			return nil
+		}
+
+		newval := val
+
+		if val.IsNil() {
+			newval = reflect.New(etype)
+		}
+
+		if err = etypeinfo.decoder(s, newval.Elem()); err == nil {
+			val.Set(newval)
+		}
+
+		return err
+	}
+}
+
+// ifsliceType is the concrete type used when decoding an RLP list into an
+// interface{} value.
+var ifsliceType = reflect.TypeOf([]interface{}{})
+
+// decodeInterface decodes into an empty interface value: lists become
+// []interface{} and strings become []byte. Non-empty interface types
+// (those with methods) are rejected.
+func decodeInterface(s *Stream, val reflect.Value) error {
+	if val.Type().NumMethod() != 0 {
+		return fmt.Errorf("rlp: type %v is not RLP-serializable", val.Type())
+	}
+
+	kind, _, err := s.Kind()
+	if err != nil {
+		return err
+	}
+
+	if kind == List {
+		slice := reflect.New(ifsliceType).Elem()
+		if err := decodeListSlice(s, slice, decodeInterface); err != nil {
+			return err
+		}
+
+		val.Set(slice)
+	} else {
+		b, err := s.Bytes()
+		if err != nil {
+			return err
+		}
+
+		val.Set(reflect.ValueOf(b))
+	}
+
+	return nil
+}
+
+// decodeDecoder invokes the value's own DecodeRLP method; used for types
+// implementing the Decoder interface.
+func decodeDecoder(s *Stream, val reflect.Value) error {
+	//nolint:forcetypeassert
+	return val.Addr().Interface().(Decoder).DecodeRLP(s)
+}
+
+// Kind represents the kind of value contained in an RLP stream.
+type Kind int8
+
+const (
+	// Byte is a single byte whose value is encoded in the type tag itself.
+	Byte Kind = iota
+	// String is an RLP string (byte array).
+	String
+	// List is an RLP list of values.
+	List
+)
+
+// String implements fmt.Stringer for debugging and error messages.
+func (k Kind) String() string {
+	switch k {
+	case Byte:
+		return "Byte"
+	case String:
+		return "String"
+	case List:
+		return "List"
+	default:
+		return fmt.Sprintf("Unknown(%d)", k)
+	}
+}
+
+// ByteReader must be implemented by any input reader for a Stream. It
+// is implemented by e.g. bufio.Reader and bytes.Reader.
+// Readers that already satisfy it avoid the bufio wrapper added by Reset.
+type ByteReader interface {
+	io.Reader
+	io.ByteReader
+}
+
+// Stream can be used for piecemeal decoding of an input stream. This
+// is useful if the input is very large or if the decoding rules for a
+// type depend on the input structure. Stream does not keep an
+// internal buffer. After decoding a value, the input reader will be
+// positioned just before the type information for the next value.
+//
+// When decoding a list and the input position reaches the declared
+// length of the list, all operations will return error EOL.
+// The end of the list must be acknowledged using ListEnd to continue
+// reading the enclosing list.
+//
+// Stream is not safe for concurrent use.
+type Stream struct {
+	r ByteReader
+
+	remaining uint64   // number of bytes remaining to be read from r
+	size      uint64   // size of value ahead
+	kinderr   error    // error from last readKind
+	stack     []uint64 // list sizes (innermost list is the last element)
+	uintbuf   [32]byte // auxiliary buffer for integer decoding
+	kind      Kind     // kind of value ahead; -1 means not yet read
+	byteval   byte     // value of single byte in type tag
+	limited   bool     // true if input limit is in effect
+}
+
+// NewStream creates a new decoding stream reading from r.
+//
+// If r implements the ByteReader interface, Stream will
+// not introduce any buffering.
+//
+// For non-toplevel values, Stream returns ErrElemTooLarge
+// for values that do not fit into the enclosing list.
+//
+// Stream supports an optional input limit. If a limit is set, the
+// size of any toplevel value will be checked against the remaining
+// input length. Stream operations that encounter a value exceeding
+// the remaining input length will return ErrValueTooLarge. The limit
+// can be set by passing a non-zero value for inputLimit.
+//
+// If r is a bytes.Reader or strings.Reader, the input limit is set to
+// the length of r's underlying data unless an explicit limit is
+// provided.
+func NewStream(r io.Reader, inputLimit uint64) *Stream {
+	s := new(Stream)
+	// Reset applies the limit (and auto-detects one for in-memory readers).
+	s.Reset(r, inputLimit)
+
+	return s
+}
+
+// NewListStream creates a new stream that pretends to be positioned
+// at an encoded list of the given length.
+func NewListStream(r io.Reader, length uint64) *Stream {
+	s := new(Stream)
+	s.Reset(r, length)
+	// Prime the cached kind/size so the first Kind call reports the
+	// list without reading a type tag from r.
+	s.kind = List
+	s.size = length
+
+	return s
+}
+
+// Bytes reads an RLP string and returns its contents as a byte slice.
+// If the input does not contain an RLP string, the returned
+// error will be ErrExpectedString.
+func (s *Stream) Bytes() ([]byte, error) {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return nil, err
+	}
+
+	switch kind {
+	case Byte:
+		s.kind = -1 // rearm Kind
+
+		return []byte{s.byteval}, nil
+	case String:
+		b := make([]byte, size)
+		if err = s.readFull(b); err != nil {
+			return nil, err
+		}
+
+		// Reject non-canonical input: a single byte < 128 must be
+		// encoded as itself, not as a one-byte string.
+		if size == 1 && b[0] < 128 {
+			return nil, ErrCanonSize
+		}
+
+		return b, nil
+	default:
+		return nil, ErrExpectedString
+	}
+}
+
+// ReadBytes decodes the next RLP value and stores the result in b.
+// The value size must match len(b) exactly.
+func (s *Stream) ReadBytes(b []byte) error {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return err
+	}
+
+	switch kind {
+	case Byte:
+		if len(b) != 1 {
+			return fmt.Errorf("input value has wrong size 1, want %d", len(b))
+		}
+
+		b[0] = s.byteval
+		s.kind = -1 // rearm Kind
+
+		return nil
+	case String:
+		if uint64(len(b)) != size {
+			return fmt.Errorf("input value has wrong size %d, want %d", size, len(b))
+		}
+
+		if err = s.readFull(b); err != nil {
+			return err
+		}
+
+		// Reject non-canonical input: a single byte < 128 must use the
+		// Byte encoding rather than a one-byte string.
+		if size == 1 && b[0] < 128 {
+			return ErrCanonSize
+		}
+
+		return nil
+	default:
+		return ErrExpectedString
+	}
+}
+
+// Raw reads a raw encoded value including RLP type information.
+func (s *Stream) Raw() ([]byte, error) {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return nil, err
+	}
+
+	if kind == Byte {
+		s.kind = -1 // rearm Kind
+
+		return []byte{s.byteval}, nil
+	}
+
+	// The original header has already been read and is no longer
+	// available. Read content and put a new header in front of it.
+	start := headsize(size) // room reserved for the re-synthesized header
+	buf := make([]byte, uint64(start)+size)
+
+	if err := s.readFull(buf[start:]); err != nil {
+		return nil, err
+	}
+
+	if kind == String {
+		puthead(buf, 0x80, 0xB7, size)
+	} else {
+		puthead(buf, 0xC0, 0xF7, size)
+	}
+
+	return buf, nil
+}
+
+// Uint reads an RLP string of up to 8 bytes and returns its contents
+// as an unsigned integer. If the input does not contain an RLP string, the
+// returned error will be ErrExpectedString.
+//
+// Deprecated: use s.Uint64 instead.
+func (s *Stream) Uint() (uint64, error) {
+	return s.uint(64)
+}
+
+// Uint64 reads an RLP string of up to 8 bytes as an unsigned integer.
+func (s *Stream) Uint64() (uint64, error) {
+	return s.uint(64)
+}
+
+// Uint32 reads an RLP string of up to 4 bytes as an unsigned integer.
+func (s *Stream) Uint32() (uint32, error) {
+	i, err := s.uint(32)
+
+	return uint32(i), err
+}
+
+// Uint16 reads an RLP string of up to 2 bytes as an unsigned integer.
+func (s *Stream) Uint16() (uint16, error) {
+	i, err := s.uint(16)
+
+	return uint16(i), err
+}
+
+// Uint8 reads an RLP string of up to 1 byte as an unsigned integer.
+func (s *Stream) Uint8() (uint8, error) {
+	i, err := s.uint(8)
+
+	return uint8(i), err
+}
+
+// uint decodes an unsigned integer of at most maxbits bits, enforcing
+// canonical encoding: no leading zero bytes (ErrCanonInt) and single
+// bytes < 128 must use the Byte encoding (ErrCanonSize).
+func (s *Stream) uint(maxbits int) (uint64, error) {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return 0, err
+	}
+
+	switch kind {
+	case Byte:
+		// Canonical zero is the empty string (0x80), not byte 0x00.
+		if s.byteval == 0 {
+			return 0, ErrCanonInt
+		}
+
+		s.kind = -1 // rearm Kind
+
+		return uint64(s.byteval), nil
+	case String:
+		if size > uint64(maxbits/8) {
+			return 0, errUintOverflow
+		}
+
+		v, err := s.readUint(byte(size))
+
+		switch {
+		case errors.Is(err, ErrCanonSize):
+			// Adjust error because we're not reading a size right now.
+			return 0, ErrCanonInt
+		case err != nil:
+			return 0, err
+		case size > 0 && v < 128:
+			// A value < 128 should have used the single-byte encoding.
+			return 0, ErrCanonSize
+		default:
+			return v, nil
+		}
+	default:
+		return 0, ErrExpectedString
+	}
+}
+
+// Bool reads an RLP string of up to 1 byte and returns its contents
+// as a boolean. If the input does not contain an RLP string, the
+// returned error will be ErrExpectedString.
+func (s *Stream) Bool() (bool, error) {
+	num, err := s.uint(8)
+	if err != nil {
+		return false, err
+	}
+
+	// num is 0 for the empty-string encoding (0x80) and 1 for 0x01;
+	// any other value is not a valid boolean.
+	switch num {
+	case 0:
+		return false, nil
+	case 1:
+		return true, nil
+	default:
+		return false, fmt.Errorf("rlp: invalid boolean value: %d", num)
+	}
+}
+
+// List starts decoding an RLP list. If the input does not contain a
+// list, the returned error will be ErrExpectedList. When the list's
+// end has been reached, any Stream operation will return EOL.
+func (s *Stream) List() (size uint64, err error) {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return 0, err
+	}
+
+	if kind != List {
+		return 0, ErrExpectedList
+	}
+
+	// Remove size of inner list from outer list before pushing the new size
+	// onto the stack. This ensures that the remaining outer list size will
+	// be correct after the matching call to ListEnd.
+	if inList, limit := s.listLimit(); inList {
+		s.stack[len(s.stack)-1] = limit - size
+	}
+
+	s.stack = append(s.stack, size) // push the new list context
+	s.kind = -1
+	s.size = 0
+
+	return size, nil
+}
+
+// ListEnd returns to the enclosing list.
+// The input reader must be positioned at the end of a list.
+// ListEnd reads no input; it only validates and pops the list context.
+func (s *Stream) ListEnd() error {
+	// Ensure that no more data is remaining in the current list.
+	if inList, listLimit := s.listLimit(); !inList {
+		return errNotInList
+	} else if listLimit > 0 {
+		return errNotAtEOL
+	}
+
+	s.stack = s.stack[:len(s.stack)-1] // pop
+	s.kind = -1
+	s.size = 0
+
+	return nil
+}
+
+// MoreDataInList reports whether the current list context contains
+// more data to be read.
+func (s *Stream) MoreDataInList() bool {
+	// listLimit is 0 both when not inside a list and at end of list.
+	_, listLimit := s.listLimit()
+
+	return listLimit > 0
+}
+
+// BigInt decodes an arbitrary-size integer value. The result is always
+// non-negative; RLP has no sign encoding.
+func (s *Stream) BigInt() (*big.Int, error) {
+	i := new(big.Int)
+	if err := s.decodeBigInt(i); err != nil {
+		return nil, err
+	}
+
+	return i, nil
+}
+
+// decodeBigInt reads an RLP string into dst, enforcing canonical encoding:
+// no leading zero bytes (ErrCanonInt) and single bytes < 128 must use the
+// Byte encoding (ErrCanonSize).
+func (s *Stream) decodeBigInt(dst *big.Int) error {
+	var buffer []byte
+
+	kind, size, err := s.Kind()
+
+	switch {
+	case err != nil:
+		return err
+	case kind == List:
+		return ErrExpectedString
+	case kind == Byte:
+		buffer = s.uintbuf[:1]
+		buffer[0] = s.byteval
+		s.kind = -1 // re-arm Kind
+	case size == 0:
+		// Avoid zero-length read.
+		s.kind = -1
+	case size <= uint64(len(s.uintbuf)):
+		// For integers smaller than s.uintbuf, allocating a buffer
+		// can be avoided.
+		buffer = s.uintbuf[:size]
+		if err := s.readFull(buffer); err != nil {
+			return err
+		}
+		// Reject inputs where single byte encoding should have been used.
+		if size == 1 && buffer[0] < 128 {
+			return ErrCanonSize
+		}
+	default:
+		// For large integers, a temporary buffer is needed.
+		buffer = make([]byte, size)
+		if err := s.readFull(buffer); err != nil {
+			return err
+		}
+	}
+
+	// Reject leading zero bytes.
+	if len(buffer) > 0 && buffer[0] == 0 {
+		return ErrCanonInt
+	}
+
+	// Set the integer bytes.
+	dst.SetBytes(buffer)
+
+	return nil
+}
+
+// Decode decodes a value and stores the result in the value pointed
+// to by val. Please see the documentation for the Decode function
+// to learn about the decoding rules.
+func (s *Stream) Decode(val interface{}) error {
+	if val == nil {
+		return errDecodeIntoNil
+	}
+
+	rval := reflect.ValueOf(val)
+	rtyp := rval.Type()
+
+	// A non-nil pointer is required so the result can be stored.
+	if rtyp.Kind() != reflect.Ptr {
+		return errNoPointer
+	}
+
+	if rval.IsNil() {
+		return errDecodeIntoNil
+	}
+
+	decoder, err := cachedDecoder(rtyp.Elem())
+	if err != nil {
+		return err
+	}
+
+	err = decoder(s, rval.Elem())
+
+	var decErr *decodeError
+	if errors.As(err, &decErr) && len(decErr.ctx) > 0 {
+		// Add decode target type to error so context has more meaning.
+		decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
+	}
+
+	return err
+}
+
+// Reset discards any information about the current decoding context
+// and starts reading from r. This method is meant to facilitate reuse
+// of a preallocated Stream across many decoding operations.
+//
+// If r does not also implement ByteReader, Stream will do its own
+// buffering.
+func (s *Stream) Reset(r io.Reader, inputLimit uint64) {
+	if inputLimit > 0 {
+		s.remaining = inputLimit
+		s.limited = true
+	} else {
+		// Attempt to automatically discover
+		// the limit when reading from a byte slice.
+		switch br := r.(type) {
+		case *bytes.Reader:
+			s.remaining = uint64(br.Len())
+			s.limited = true
+		case *bytes.Buffer:
+			s.remaining = uint64(br.Len())
+			s.limited = true
+		case *strings.Reader:
+			s.remaining = uint64(br.Len())
+			s.limited = true
+		default:
+			s.limited = false
+		}
+	}
+	// Wrap r with a buffer if it doesn't have one.
+	bufr, ok := r.(ByteReader)
+	if !ok {
+		bufr = bufio.NewReader(r)
+	}
+
+	s.r = bufr
+	// Reset the decoding context.
+	s.stack = s.stack[:0]
+	s.size = 0
+	s.kind = -1
+	s.kinderr = nil
+	s.byteval = 0
+	s.uintbuf = [32]byte{} // clear the integer scratch buffer
+}
+
+// Kind returns the kind and size of the next value in the
+// input stream.
+//
+// The returned size is the number of bytes that make up the value.
+// For kind == Byte, the size is zero because the value is
+// contained in the type tag.
+//
+// The first call to Kind will read size information from the input
+// reader and leave it positioned at the start of the actual bytes of
+// the value. Subsequent calls to Kind (until the value is decoded)
+// will not advance the input reader and return cached information.
+func (s *Stream) Kind() (kind Kind, size uint64, err error) {
+	// Return cached info if the type tag has already been read.
+	if s.kind >= 0 {
+		return s.kind, s.size, s.kinderr
+	}
+
+	// Check for end of list. This needs to be done here because readKind
+	// checks against the list size, and would return the wrong error.
+	inList, listLimit := s.listLimit()
+	if inList && listLimit == 0 {
+		return 0, 0, EOL
+	}
+
+	// Read the actual size tag.
+	s.kind, s.size, s.kinderr = s.readKind()
+	if s.kinderr == nil {
+		// Check the data size of the value ahead against input limits. This
+		// is done here because many decoders require allocating an input
+		// buffer matching the value size. Checking it here protects those
+		// decoders from inputs declaring very large value size.
+		if inList && s.size > listLimit {
+			s.kinderr = ErrElemTooLarge
+		} else if s.limited && s.size > s.remaining {
+			s.kinderr = ErrValueTooLarge
+		}
+	}
+
+	return s.kind, s.size, s.kinderr
+}
+
+// readKind reads and parses the type tag of the next value, returning its
+// kind and payload size. At toplevel, EOF-like failures are normalized to
+// io.EOF so callers can detect the end of input.
+func (s *Stream) readKind() (kind Kind, size uint64, err error) {
+	b, err := s.readByte()
+	if err != nil {
+		if len(s.stack) == 0 {
+			// At toplevel, Adjust the error to actual EOF. io.EOF is
+			// used by callers to determine when to stop decoding.
+			if errors.Is(err, io.ErrUnexpectedEOF) {
+				err = io.EOF
+			} else if errors.Is(err, ErrValueTooLarge) {
+				err = io.EOF
+			}
+		}
+
+		return 0, 0, err
+	}
+
+	s.byteval = 0
+
+	switch {
+	case b < 0x80:
+		// For a single byte whose value is in the [0x00, 0x7F] range, that byte
+		// is its own RLP encoding.
+		s.byteval = b
+
+		return Byte, 0, nil
+	case b < 0xB8:
+		// Otherwise, if a string is 0-55 bytes long, the RLP encoding consists
+		// of a single byte with value 0x80 plus the length of the string
+		// followed by the string. The range of the first byte is thus [0x80, 0xB7].
+		return String, uint64(b - 0x80), nil
+	case b < 0xC0:
+		// If a string is more than 55 bytes long, the RLP encoding consists of a
+		// single byte with value 0xB7 plus the length of the length of the
+		// string in binary form, followed by the length of the string, followed
+		// by the string. For example, a length-1024 string would be encoded as
+		// 0xB90400 followed by the string. The range of the first byte is thus
+		// [0xB8, 0xBF].
+		size, err = s.readUint(b - 0xB7)
+		if err == nil && size < 56 {
+			// Sizes < 56 must use the short form; long form is non-canonical.
+			err = ErrCanonSize
+		}
+
+		return String, size, err
+	case b < 0xF8:
+		// If the total payload of a list (i.e. the combined length of all its
+		// items) is 0-55 bytes long, the RLP encoding consists of a single byte
+		// with value 0xC0 plus the length of the list followed by the
+		// concatenation of the RLP encodings of the items. The range of the
+		// first byte is thus [0xC0, 0xF7].
+		return List, uint64(b - 0xC0), nil
+	default:
+		// If the total payload of a list is more than 55 bytes long, the RLP
+		// encoding consists of a single byte with value 0xF7 plus the length of
+		// the length of the payload in binary form, followed by the length of
+		// the payload, followed by the concatenation of the RLP encodings of
+		// the items. The range of the first byte is thus [0xF8, 0xFF].
+		size, err = s.readUint(b - 0xF7)
+		if err == nil && size < 56 {
+			// Sizes < 56 must use the short form; long form is non-canonical.
+			err = ErrCanonSize
+		}
+
+		return List, size, err
+	}
+}
+
+// readUint reads a big-endian integer of the given byte size (0-8).
+// A leading zero byte is rejected as non-canonical (ErrCanonSize).
+func (s *Stream) readUint(size byte) (uint64, error) {
+	switch size {
+	case 0:
+		s.kind = -1 // rearm Kind
+
+		return 0, nil
+	case 1:
+		b, err := s.readByte()
+
+		return uint64(b), err
+	default:
+		// Zero the high bytes of the scratch buffer, then read the value
+		// right-aligned so the big-endian conversion works for any size.
+		buffer := s.uintbuf[:8]
+		for i := range buffer {
+			buffer[i] = 0
+		}
+
+		start := int(8 - size)
+		if err := s.readFull(buffer[start:]); err != nil {
+			return 0, err
+		}
+
+		if buffer[start] == 0 {
+			// Note: readUint is also used to decode integer values.
+			// The error needs to be adjusted to become ErrCanonInt in this case.
+			return 0, ErrCanonSize
+		}
+
+		return binary.BigEndian.Uint64(buffer[:]), nil
+	}
+}
+
+// readFull reads into buf from the underlying stream.
+func (s *Stream) readFull(buf []byte) (err error) {
+	if err := s.willRead(uint64(len(buf))); err != nil {
+		return err
+	}
+
+	// Loop until buf is full: Read is allowed to return short counts.
+	var nn, n int
+	for n < len(buf) && err == nil {
+		nn, err = s.r.Read(buf[n:])
+		n += nn
+	}
+
+	if errors.Is(err, io.EOF) {
+		if n < len(buf) {
+			err = io.ErrUnexpectedEOF
+		} else {
+			// Readers are allowed to give EOF even though the read succeeded.
+			// In such cases, we discard the EOF, like io.ReadFull() does.
+			err = nil
+		}
+	}
+
+	return err
+}
+
+// readByte reads a single byte from the underlying stream.
+func (s *Stream) readByte() (byte, error) {
+	if err := s.willRead(1); err != nil {
+		return 0, err
+	}
+
+	b, err := s.r.ReadByte()
+	if errors.Is(err, io.EOF) {
+		// EOF mid-value is unexpected here; readKind translates it back
+		// to io.EOF at toplevel.
+		err = io.ErrUnexpectedEOF
+	}
+
+	return b, err
+}
+
+// willRead is called before any read from the underlying stream. It checks
+// n against size limits, and updates the limits if n doesn't overflow them.
+// The innermost list limit is deducted first, then the global input limit.
+func (s *Stream) willRead(n uint64) error {
+	s.kind = -1 // rearm Kind
+
+	if inList, limit := s.listLimit(); inList {
+		if n > limit {
+			return ErrElemTooLarge
+		}
+
+		s.stack[len(s.stack)-1] = limit - n
+	}
+
+	if s.limited {
+		if n > s.remaining {
+			return ErrValueTooLarge
+		}
+
+		s.remaining -= n
+	}
+
+	return nil
+}
+
+// listLimit returns the amount of data remaining in the innermost list.
+// The innermost list is the top (last element) of the stack.
+func (s *Stream) listLimit() (inList bool, limit uint64) {
+	if len(s.stack) == 0 {
+		return false, 0
+	}
+
+	return true, s.stack[len(s.stack)-1]
+}
diff --git a/helper/rlp/decode_tail_test.go b/helper/rlp/decode_tail_test.go
new file mode 100644
index 0000000000..884c1148b2
--- /dev/null
+++ b/helper/rlp/decode_tail_test.go
@@ -0,0 +1,49 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "bytes"
+ "fmt"
+)
+
+// structWithTail exercises the "tail" struct tag: field C absorbs all
+// list elements remaining after A and B.
+type structWithTail struct {
+	A, B uint
+	C    []uint `rlp:"tail"`
+}
+
+func ExampleDecode_structTagTail() {
+	// In this example, the "tail" struct tag is used to decode lists of
+	// differing length into a struct.
+	var val structWithTail
+
+	// 0xC4 is a list header with a 4-byte payload (0xC0 + 4).
+	err := Decode(bytes.NewReader([]byte{0xC4, 0x01, 0x02, 0x03, 0x04}), &val)
+	fmt.Printf("with 4 elements: err=%v val=%v\n", err, val)
+
+	err = Decode(bytes.NewReader([]byte{0xC6, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06}), &val)
+	fmt.Printf("with 6 elements: err=%v val=%v\n", err, val)
+
+	// Note that at least two list elements must be present to
+	// fill fields A and B:
+	err = Decode(bytes.NewReader([]byte{0xC1, 0x01}), &val)
+	fmt.Printf("with 1 element: err=%q\n", err)
+
+	// Output:
+	// with 4 elements: err= val={1 2 [3 4]}
+	// with 6 elements: err= val={1 2 [3 4 5 6]}
+	// with 1 element: err="rlp: too few elements for rlp.structWithTail"
+}
diff --git a/helper/rlp/decode_test.go b/helper/rlp/decode_test.go
new file mode 100644
index 0000000000..5fadec525c
--- /dev/null
+++ b/helper/rlp/decode_test.go
@@ -0,0 +1,1306 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "bytes"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+// TestStreamKind checks that Kind reports the correct kind and payload
+// size for each range of type-tag bytes (Byte, String, List; short and
+// long forms).
+func TestStreamKind(t *testing.T) {
+	tests := []struct {
+		input    string
+		wantKind Kind
+		wantLen  uint64
+	}{
+		{"00", Byte, 0},
+		{"01", Byte, 0},
+		{"7F", Byte, 0},
+		{"80", String, 0},
+		{"B7", String, 55},
+		{"B90400", String, 1024},
+		{"BFFFFFFFFFFFFFFFFF", String, ^uint64(0)},
+		{"C0", List, 0},
+		{"C8", List, 8},
+		{"F7", List, 55},
+		{"F90400", List, 1024},
+		{"FFFFFFFFFFFFFFFFFF", List, ^uint64(0)},
+	}
+
+	for i, test := range tests {
+		// using plainReader to inhibit input limit errors.
+		s := NewStream(newPlainReader(unhex(test.input)), 0)
+
+		kind, length, err := s.Kind()
+		if err != nil {
+			t.Errorf("test %d: Kind returned error: %v", i, err)
+
+			continue
+		}
+
+		if kind != test.wantKind {
+			t.Errorf("test %d: kind mismatch: got %d, want %d", i, kind, test.wantKind)
+		}
+
+		if length != test.wantLen {
+			t.Errorf("test %d: len mismatch: got %d, want %d", i, length, test.wantLen)
+		}
+	}
+}
+
+// TestNewListStream verifies that NewListStream reports a pre-positioned
+// list of the given length and its elements decode normally.
+func TestNewListStream(t *testing.T) {
+	ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3)
+
+	if k, size, err := ls.Kind(); k != List || size != 3 || err != nil {
+		t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err)
+	}
+
+	if size, err := ls.List(); size != 3 || err != nil {
+		t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err)
+	}
+
+	for i := 0; i < 3; i++ {
+		if val, err := ls.Uint(); val != 1 || err != nil {
+			t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err)
+		}
+	}
+
+	if err := ls.ListEnd(); err != nil {
+		t.Errorf("ListEnd() returned %v, expected (3, nil)", err)
+	}
+}
+
+// TestStreamErrors drives the Stream API through scripted call sequences
+// (via reflection) against malformed, non-canonical, truncated, and
+// over-limit inputs, asserting the exact error of the final call.
+func TestStreamErrors(t *testing.T) {
+	withoutInputLimit := func(b []byte) *Stream {
+		return NewStream(newPlainReader(b), 0)
+	}
+	withCustomInputLimit := func(limit uint64) func([]byte) *Stream {
+		return func(b []byte) *Stream {
+			return NewStream(bytes.NewReader(b), limit)
+		}
+	}
+
+	type calls []string
+
+	tests := []struct {
+		string
+		calls
+		newStream func([]byte) *Stream // uses bytes.Reader if nil
+		error     error
+	}{
+		{"C0", calls{"Bytes"}, nil, ErrExpectedString},
+		{"C0", calls{"Uint"}, nil, ErrExpectedString},
+		{"89000000000000000001", calls{"Uint"}, nil, errUintOverflow},
+		{"00", calls{"List"}, nil, ErrExpectedList},
+		{"80", calls{"List"}, nil, ErrExpectedList},
+		{"C0", calls{"List", "Uint"}, nil, EOL},
+		{"C8C9010101010101010101", calls{"List", "Kind"}, nil, ErrElemTooLarge},
+		{"C3C2010201", calls{"List", "List", "Uint", "Uint", "ListEnd", "Uint"}, nil, EOL},
+		{"00", calls{"ListEnd"}, nil, errNotInList},
+		{"C401020304", calls{"List", "Uint", "ListEnd"}, nil, errNotAtEOL},
+
+		// Non-canonical integers (e.g. leading zero bytes).
+		{"00", calls{"Uint"}, nil, ErrCanonInt},
+		{"820002", calls{"Uint"}, nil, ErrCanonInt},
+		{"8133", calls{"Uint"}, nil, ErrCanonSize},
+		{"817F", calls{"Uint"}, nil, ErrCanonSize},
+		{"8180", calls{"Uint"}, nil, nil},
+
+		// Non-valid boolean
+		{"02", calls{"Bool"}, nil, errors.New("rlp: invalid boolean value: 2")},
+
+		// Size tags must use the smallest possible encoding.
+		// Leading zero bytes in the size tag are also rejected.
+		{"8100", calls{"Uint"}, nil, ErrCanonSize},
+		{"8100", calls{"Bytes"}, nil, ErrCanonSize},
+		{"8101", calls{"Bytes"}, nil, ErrCanonSize},
+		{"817F", calls{"Bytes"}, nil, ErrCanonSize},
+		{"8180", calls{"Bytes"}, nil, nil},
+		{"B800", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"B90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"B90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"BA0002FFFF", calls{"Bytes"}, withoutInputLimit, ErrCanonSize},
+		{"F800", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"F90000", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"F90055", calls{"Kind"}, withoutInputLimit, ErrCanonSize},
+		{"FA0002FFFF", calls{"List"}, withoutInputLimit, ErrCanonSize},
+
+		// Expected EOF
+		{"", calls{"Kind"}, nil, io.EOF},
+		{"", calls{"Uint"}, nil, io.EOF},
+		{"", calls{"List"}, nil, io.EOF},
+		{"8180", calls{"Uint", "Uint"}, nil, io.EOF},
+		{"C0", calls{"List", "ListEnd", "List"}, nil, io.EOF},
+
+		{"", calls{"List"}, withoutInputLimit, io.EOF},
+		{"8180", calls{"Uint", "Uint"}, withoutInputLimit, io.EOF},
+		{"C0", calls{"List", "ListEnd", "List"}, withoutInputLimit, io.EOF},
+
+		// Input limit errors.
+		{"81", calls{"Bytes"}, nil, ErrValueTooLarge},
+		{"81", calls{"Uint"}, nil, ErrValueTooLarge},
+		{"81", calls{"Raw"}, nil, ErrValueTooLarge},
+		{"BFFFFFFFFFFFFFFFFFFF", calls{"Bytes"}, nil, ErrValueTooLarge},
+		{"C801", calls{"List"}, nil, ErrValueTooLarge},
+
+		// Test for list element size check overflow.
+		{"CD04040404FFFFFFFFFFFFFFFFFF0303", calls{"List", "Uint", "Uint", "Uint", "Uint", "List"}, nil, ErrElemTooLarge},
+
+		// Test for input limit overflow. Since we are counting the limit
+		// down toward zero in Stream.remaining, reading too far can overflow
+		// remaining to a large value, effectively disabling the limit.
+		{"C40102030401", calls{"Raw", "Uint"}, withCustomInputLimit(5), io.EOF},
+		{"C4010203048180", calls{"Raw", "Uint"}, withCustomInputLimit(6), ErrValueTooLarge},
+
+		// Check that the same calls are fine without a limit.
+		{"C40102030401", calls{"Raw", "Uint"}, withoutInputLimit, nil},
+		{"C4010203048180", calls{"Raw", "Uint"}, withoutInputLimit, nil},
+
+		// Unexpected EOF. This only happens when there is
+		// no input limit, so the reader needs to be 'dumbed down'.
+		{"81", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF},
+		{"81", calls{"Uint"}, withoutInputLimit, io.ErrUnexpectedEOF},
+		{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, withoutInputLimit, io.ErrUnexpectedEOF},
+		{"C801", calls{"List", "Uint", "Uint"}, withoutInputLimit, io.ErrUnexpectedEOF},
+
+		// This test verifies that the input position is advanced
+		// correctly when calling Bytes for empty strings. Kind can be called
+		// any number of times in between and doesn't advance.
+		{"C3808080", calls{
+			"List",  // enter the list
+			"Bytes", // past first element
+
+			"Kind", "Kind", "Kind", // this shouldn't advance
+
+			"Bytes", // past second element
+
+			"Kind", "Kind", // can't hurt to try
+
+			"Bytes", // past final element
+			"Bytes", // this one should fail
+		}, nil, EOL},
+	}
+
+	const nilVal = ""
+
+testfor:
+	for i, test := range tests {
+		if test.newStream == nil {
+			test.newStream = func(b []byte) *Stream { return NewStream(bytes.NewReader(b), 0) }
+		}
+		s := test.newStream(unhex(test.string))
+		rs := reflect.ValueOf(s)
+		for j, call := range test.calls {
+			// Invoke the named Stream method; only its trailing error
+			// return is inspected.
+			fval := rs.MethodByName(call)
+			ret := fval.Call(nil)
+			err := nilVal
+
+			if lastret := ret[len(ret)-1].Interface(); lastret != nil {
+				err = lastret.(error).Error()
+			}
+
+			if j == len(test.calls)-1 {
+				want := nilVal
+				if test.error != nil {
+					want = test.error.Error()
+				}
+				if err != want {
+					t.Log(test)
+					t.Errorf("test %d: last call (%s) error mismatch\ngot: %s\nwant: %s",
+						i, call, err, test.error)
+				}
+			} else if err != nilVal {
+				t.Log(test)
+				t.Errorf("test %d: call %d (%s) unexpected error: %q", i, j, call, err)
+
+				continue testfor
+			}
+		}
+	}
+}
+
+// TestStreamList decodes a list of eight uints element by element and
+// verifies the reported list length, the EOL error past the last element,
+// and a clean ListEnd once the list is exhausted.
+func TestStreamList(t *testing.T) {
+	s := NewStream(bytes.NewReader(unhex("C80102030405060708")), 0)
+
+	length, err := s.List()
+	if err != nil {
+		t.Fatalf("List error: %v", err)
+	}
+
+	if length != 8 {
+		t.Fatalf("List returned invalid length, got %d, want 8", length)
+	}
+
+	for i := uint64(1); i <= 8; i++ {
+		v, err := s.Uint()
+		if err != nil {
+			t.Fatalf("Uint error: %v", err)
+		}
+
+		if i != v {
+			t.Errorf("Uint returned wrong value, got %d, want %d", v, i)
+		}
+	}
+
+	// Reading past the final element must report end-of-list.
+	if _, err := s.Uint(); !errors.Is(err, EOL) {
+		t.Errorf("Uint error mismatch, got %v, want %v", err, EOL)
+	}
+
+	if err = s.ListEnd(); err != nil {
+		t.Fatalf("ListEnd error: %v", err)
+	}
+}
+
+// TestStreamRaw checks that Raw returns the complete encoding of the next
+// value (header plus payload) without decoding it, for elements inside lists.
+func TestStreamRaw(t *testing.T) {
+	tests := []struct {
+		input  string
+		output string
+	}{
+		{
+			"C58401010101",
+			"8401010101",
+		},
+		{
+			"F842B84001010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
+			"B84001010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101",
+		},
+	}
+	for i, tt := range tests {
+		s := NewStream(bytes.NewReader(unhex(tt.input)), 0)
+		// Entering the list can fail on malformed input; previously the
+		// error was ignored, which would mask a broken test vector.
+		if _, err := s.List(); err != nil {
+			t.Fatalf("test %d: List error: %v", i, err)
+		}
+
+		want := unhex(tt.output)
+
+		raw, err := s.Raw()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if !bytes.Equal(want, raw) {
+			t.Errorf("test %d: raw mismatch: got %x, want %x", i, raw, want)
+		}
+	}
+}
+
+// TestStreamReadBytes checks ReadBytes against each value kind: a List is
+// rejected outright, while Byte and String values must match the destination
+// buffer length exactly — both shorter and longer buffers are errors.
+func TestStreamReadBytes(t *testing.T) {
+	tests := []struct {
+		input string
+		size  int
+		err   string
+	}{
+		// kind List
+		{input: "C0", size: 1, err: "rlp: expected String or Byte"},
+		// kind Byte
+		{input: "04", size: 0, err: "input value has wrong size 1, want 0"},
+		{input: "04", size: 1},
+		{input: "04", size: 2, err: "input value has wrong size 1, want 2"},
+		// kind String
+		{input: "820102", size: 0, err: "input value has wrong size 2, want 0"},
+		{input: "820102", size: 1, err: "input value has wrong size 2, want 1"},
+		{input: "820102", size: 2},
+		{input: "820102", size: 3, err: "input value has wrong size 2, want 3"},
+	}
+
+	for _, test := range tests {
+		// Capture the range variable for the subtest closure.
+		test := test
+		name := fmt.Sprintf("input_%s/size_%d", test.input, test.size)
+		t.Run(name, func(t *testing.T) {
+			s := NewStream(bytes.NewReader(unhex(test.input)), 0)
+			b := make([]byte, test.size)
+			err := s.ReadBytes(b)
+			if test.err == "" {
+				if err != nil {
+					t.Errorf("unexpected error %q", err)
+				}
+			} else {
+				if err == nil {
+					t.Errorf("expected error, got nil")
+				} else if err.Error() != test.err {
+					t.Errorf("wrong error %q", err)
+				}
+			}
+		})
+	}
+}
+
+// TestDecodeErrors exercises the top-level Decode error paths: nil targets,
+// non-pointer targets, unsupported element types, and empty input.
+func TestDecodeErrors(t *testing.T) {
+	r := bytes.NewReader(nil)
+
+	if err := Decode(r, nil); !errors.Is(err, errDecodeIntoNil) {
+		t.Errorf("Decode(r, nil) error mismatch, got %q, want %q", err, errDecodeIntoNil)
+	}
+
+	var nilptr *struct{}
+	if err := Decode(r, nilptr); !errors.Is(err, errDecodeIntoNil) {
+		t.Errorf("Decode(r, nilptr) error mismatch, got %q, want %q", err, errDecodeIntoNil)
+	}
+
+	if err := Decode(r, struct{}{}); !errors.Is(err, errNoPointer) {
+		t.Errorf("Decode(r, struct{}{}) error mismatch, got %q, want %q", err, errNoPointer)
+	}
+
+	expectErr := "rlp: type chan bool is not RLP-serializable"
+	if err := Decode(r, new(chan bool)); err == nil || err.Error() != expectErr {
+		t.Errorf("Decode(r, new(chan bool)) error mismatch, got %q, want %q", err, expectErr)
+	}
+
+	// The failure message previously said new(int), which did not match
+	// the actual call under test.
+	if err := Decode(r, new(uint)); !errors.Is(err, io.EOF) {
+		t.Errorf("Decode(r, new(uint)) error mismatch, got %q, want %q", err, io.EOF)
+	}
+}
+
+type decodeTest struct {
+ input string
+ ptr interface{}
+ value interface{}
+ error string
+}
+
+type simplestruct struct {
+ A uint
+ B string
+}
+
+type recstruct struct {
+ I uint
+ Child *recstruct `rlp:"nil"`
+}
+
+type bigIntStruct struct {
+ I *big.Int
+ B string
+}
+
+type invalidNilTag struct {
+ X []byte `rlp:"nil"`
+}
+
+type invalidTail1 struct {
+ A uint `rlp:"tail"`
+ B string
+}
+
+type invalidTail2 struct {
+ A uint
+ B string `rlp:"tail"`
+}
+
+type tailRaw struct {
+ A uint
+ Tail []RawValue `rlp:"tail"`
+}
+
+type tailUint struct {
+ A uint
+ Tail []uint `rlp:"tail"`
+}
+
+type tailPrivateFields struct {
+ A uint
+ Tail []uint `rlp:"tail"`
+ x, y bool //lint:ignore U1000 unused fields required for testing purposes.
+}
+
+type nilListUint struct {
+ X *uint `rlp:"nilList"`
+}
+
+type nilStringSlice struct {
+ X *[]uint `rlp:"nilString"`
+}
+
+type intField struct {
+ X int
+}
+
+type optionalFields struct {
+ A uint
+ B uint `rlp:"optional"`
+ C uint `rlp:"optional"`
+}
+
+type optionalAndTailField struct {
+ A uint
+ B uint `rlp:"optional"`
+ Tail []uint `rlp:"tail"`
+}
+
+type optionalBigIntField struct {
+ A uint
+ B *big.Int `rlp:"optional"`
+}
+
+type optionalPtrField struct {
+ A uint
+ B *[3]byte `rlp:"optional"`
+}
+
+type nonOptionalPtrField struct {
+ A uint
+ B *[3]byte
+}
+
+type multipleOptionalFields struct {
+ A *[3]byte `rlp:"optional"`
+ B *[3]byte `rlp:"optional"`
+}
+
+type optionalPtrFieldNil struct {
+ A uint
+ B *[3]byte `rlp:"optional,nil"`
+}
+
+type ignoredField struct {
+ A uint
+ B uint `rlp:"-"`
+ C uint
+}
+
+var (
+ veryBigInt = new(big.Int).Add(
+ new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
+ big.NewInt(0xFFFF),
+ )
+ veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil)
+)
+
+var decodeTests = []decodeTest{
+ // booleans
+ {input: "01", ptr: new(bool), value: true},
+ {input: "80", ptr: new(bool), value: false},
+ {input: "02", ptr: new(bool), error: "rlp: invalid boolean value: 2"},
+
+ // integers
+ {input: "05", ptr: new(uint32), value: uint32(5)},
+ {input: "80", ptr: new(uint32), value: uint32(0)},
+ {input: "820505", ptr: new(uint32), value: uint32(0x0505)},
+ {input: "83050505", ptr: new(uint32), value: uint32(0x050505)},
+ {input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)},
+ {input: "850505050505", ptr: new(uint32), error: "rlp: input string too long for uint32"},
+ {input: "C0", ptr: new(uint32), error: "rlp: expected input string or byte for uint32"},
+ {input: "00", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"},
+ {input: "8105", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"},
+ {input: "820004", ptr: new(uint32), error: "rlp: non-canonical integer (leading zero bytes) for uint32"},
+ {input: "B8020004", ptr: new(uint32), error: "rlp: non-canonical size information for uint32"},
+
+ // slices
+ {input: "C0", ptr: new([]uint), value: []uint{}},
+ {input: "C80102030405060708", ptr: new([]uint), value: []uint{1, 2, 3, 4, 5, 6, 7, 8}},
+ {input: "F8020004", ptr: new([]uint), error: "rlp: non-canonical size information for []uint"},
+
+ // arrays
+ {input: "C50102030405", ptr: new([5]uint), value: [5]uint{1, 2, 3, 4, 5}},
+ {input: "C0", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"},
+ {input: "C102", ptr: new([5]uint), error: "rlp: input list has too few elements for [5]uint"},
+ {input: "C6010203040506", ptr: new([5]uint), error: "rlp: input list has too many elements for [5]uint"},
+ {input: "F8020004", ptr: new([5]uint), error: "rlp: non-canonical size information for [5]uint"},
+
+ // zero sized arrays
+ {input: "C0", ptr: new([0]uint), value: [0]uint{}},
+ {input: "C101", ptr: new([0]uint), error: "rlp: input list has too many elements for [0]uint"},
+
+ // byte slices
+ {input: "01", ptr: new([]byte), value: []byte{1}},
+ {input: "80", ptr: new([]byte), value: []byte{}},
+ {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
+ {input: "C0", ptr: new([]byte), error: "rlp: expected input string or byte for []uint8"},
+ {input: "8105", ptr: new([]byte), error: "rlp: non-canonical size information for []uint8"},
+
+ // byte arrays
+ {input: "02", ptr: new([1]byte), value: [1]byte{2}},
+ {input: "8180", ptr: new([1]byte), value: [1]byte{128}},
+ {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
+
+ // byte array errors
+ {input: "02", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "80", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "820000", ptr: new([5]byte), error: "rlp: input string too short for [5]uint8"},
+ {input: "C0", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"},
+ {input: "C3010203", ptr: new([5]byte), error: "rlp: expected input string or byte for [5]uint8"},
+ {input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"},
+ {input: "8105", ptr: new([1]byte), error: "rlp: non-canonical size information for [1]uint8"},
+ {input: "817F", ptr: new([1]byte), error: "rlp: non-canonical size information for [1]uint8"},
+
+ // zero sized byte arrays
+ {input: "80", ptr: new([0]byte), value: [0]byte{}},
+ {input: "01", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"},
+ {input: "8101", ptr: new([0]byte), error: "rlp: input string too long for [0]uint8"},
+
+ // strings
+ {input: "00", ptr: new(string), value: "\000"},
+ {input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"},
+ {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"},
+
+ // big ints
+ {input: "80", ptr: new(*big.Int), value: big.NewInt(0)},
+ {input: "01", ptr: new(*big.Int), value: big.NewInt(1)},
+ {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt},
+ {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt},
+ {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works
+ {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"},
+ {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
+ {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"},
+ {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"},
+
+ // structs
+ {
+ input: "C50583343434",
+ ptr: new(simplestruct),
+ value: simplestruct{5, "444"},
+ },
+ {
+ input: "C601C402C203C0",
+ ptr: new(recstruct),
+ value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
+ },
+ {
+ // This checks that empty big.Int works correctly in struct context. It's easy to
+ // miss the update of s.kind for this case, so it needs its own test.
+ input: "C58083343434",
+ ptr: new(bigIntStruct),
+ value: bigIntStruct{new(big.Int), "444"},
+ },
+
+ // struct errors
+ {
+ input: "C0",
+ ptr: new(simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct",
+ },
+ {
+ input: "C105",
+ ptr: new(simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct",
+ },
+ {
+ input: "C7C50583343434C0",
+ ptr: new([]*simplestruct),
+ error: "rlp: too few elements for rlp.simplestruct, decoding into ([]*rlp.simplestruct)[1]",
+ },
+ {
+ input: "83222222",
+ ptr: new(simplestruct),
+ error: "rlp: expected input list for rlp.simplestruct",
+ },
+ {
+ input: "C3010101",
+ ptr: new(simplestruct),
+ error: "rlp: input list has too many elements for rlp.simplestruct",
+ },
+ {
+ input: "C501C3C00000",
+ ptr: new(recstruct),
+ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
+ },
+ {
+ input: "C103",
+ ptr: new(intField),
+ error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)",
+ },
+ {
+ input: "C50102C20102",
+ ptr: new(tailUint),
+ error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]",
+ },
+ {
+ input: "C0",
+ ptr: new(invalidNilTag),
+ error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`,
+ },
+
+ // struct tag "tail"
+ {
+ input: "C3010203",
+ ptr: new(tailRaw),
+ value: tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}},
+ },
+ {
+ input: "C20102",
+ ptr: new(tailRaw),
+ value: tailRaw{A: 1, Tail: []RawValue{unhex("02")}},
+ },
+ {
+ input: "C101",
+ ptr: new(tailRaw),
+ value: tailRaw{A: 1, Tail: []RawValue{}},
+ },
+ {
+ input: "C3010203",
+ ptr: new(tailPrivateFields),
+ value: tailPrivateFields{A: 1, Tail: []uint{2, 3}},
+ },
+ {
+ input: "C0",
+ ptr: new(invalidTail1),
+ error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`,
+ },
+ {
+ input: "C0",
+ ptr: new(invalidTail2),
+ error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`,
+ },
+
+ // struct tag "-"
+ {
+ input: "C20102",
+ ptr: new(ignoredField),
+ value: ignoredField{A: 1, C: 2},
+ },
+
+ // struct tag "nilList"
+ {
+ input: "C180",
+ ptr: new(nilListUint),
+ error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X",
+ },
+ {
+ input: "C1C0",
+ ptr: new(nilListUint),
+ value: nilListUint{},
+ },
+ {
+ input: "C103",
+ ptr: new(nilListUint),
+ value: func() interface{} {
+ v := uint(3)
+
+ return nilListUint{X: &v}
+ }(),
+ },
+
+ // struct tag "nilString"
+ {
+ input: "C1C0",
+ ptr: new(nilStringSlice),
+ error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X",
+ },
+ {
+ input: "C180",
+ ptr: new(nilStringSlice),
+ value: nilStringSlice{},
+ },
+ {
+ input: "C2C103",
+ ptr: new(nilStringSlice),
+ value: nilStringSlice{X: &[]uint{3}},
+ },
+
+ // struct tag "optional"
+ {
+ input: "C101",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 0, 0},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 2, 0},
+ },
+ {
+ input: "C3010203",
+ ptr: new(optionalFields),
+ value: optionalFields{1, 2, 3},
+ },
+ {
+ input: "C401020304",
+ ptr: new(optionalFields),
+ error: "rlp: input list has too many elements for rlp.optionalFields",
+ },
+ {
+ input: "C101",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
+ },
+ {
+ input: "C401020304",
+ ptr: new(optionalAndTailField),
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}},
+ },
+ {
+ input: "C101",
+ ptr: new(optionalBigIntField),
+ value: optionalBigIntField{A: 1, B: nil},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalBigIntField),
+ value: optionalBigIntField{A: 1, B: big.NewInt(2)},
+ },
+ {
+ input: "C101",
+ ptr: new(optionalPtrField),
+ value: optionalPtrField{A: 1},
+ },
+ {
+ input: "C20180", // not accepted because "optional" doesn't enable "nil"
+ ptr: new(optionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B",
+ },
+ {
+ input: "C50183010203",
+ ptr: new(optionalPtrField),
+ value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}},
+ },
+ {
+ // all optional fields nil
+ input: "C0",
+ ptr: new(multipleOptionalFields),
+ value: multipleOptionalFields{A: nil, B: nil},
+ },
+ {
+ // all optional fields set
+ input: "C88301020383010203",
+ ptr: new(multipleOptionalFields),
+ value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}},
+ },
+ {
+ // nil optional field appears before a non-nil one
+ input: "C58083010203",
+ ptr: new(multipleOptionalFields),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A",
+ },
+ {
+ // decode a nil ptr into a ptr that is not nil or not optional
+ input: "C20180",
+ ptr: new(nonOptionalPtrField),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B",
+ },
+ {
+ input: "C101",
+ ptr: new(optionalPtrFieldNil),
+ value: optionalPtrFieldNil{A: 1},
+ },
+ {
+ input: "C20180", // accepted because "nil" tag allows empty input
+ ptr: new(optionalPtrFieldNil),
+ value: optionalPtrFieldNil{A: 1},
+ },
+ {
+ input: "C20102",
+ ptr: new(optionalPtrFieldNil),
+ error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B",
+ },
+
+ // struct tag "optional" field clearing
+ {
+ input: "C101",
+ ptr: &optionalFields{A: 9, B: 8, C: 7},
+ value: optionalFields{A: 1, B: 0, C: 0},
+ },
+ {
+ input: "C20102",
+ ptr: &optionalFields{A: 9, B: 8, C: 7},
+ value: optionalFields{A: 1, B: 2, C: 0},
+ },
+ {
+ input: "C20102",
+ ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}},
+ value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}},
+ },
+ {
+ input: "C101",
+ ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}},
+ value: optionalPtrField{A: 1},
+ },
+
+ // RawValue
+ {input: "01", ptr: new(RawValue), value: RawValue(unhex("01"))},
+ {input: "82FFFF", ptr: new(RawValue), value: RawValue(unhex("82FFFF"))},
+ {input: "C20102", ptr: new([]RawValue), value: []RawValue{unhex("01"), unhex("02")}},
+
+ // pointers
+ {input: "00", ptr: new(*[]byte), value: &[]byte{0}},
+ {input: "80", ptr: new(*uint), value: uintp(0)},
+ {input: "C0", ptr: new(*uint), error: "rlp: expected input string or byte for uint"},
+ {input: "07", ptr: new(*uint), value: uintp(7)},
+ {input: "817F", ptr: new(*uint), error: "rlp: non-canonical size information for uint"},
+ {input: "8180", ptr: new(*uint), value: uintp(0x80)},
+ {input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
+ {input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
+
+ // check that input position is advanced also for empty values.
+ {input: "C3808005", ptr: new([]*uint), value: []*uint{uintp(0), uintp(0), uintp(5)}},
+
+ // interface{}
+ {input: "00", ptr: new(interface{}), value: []byte{0}},
+ {input: "01", ptr: new(interface{}), value: []byte{1}},
+ {input: "80", ptr: new(interface{}), value: []byte{}},
+ {input: "850505050505", ptr: new(interface{}), value: []byte{5, 5, 5, 5, 5}},
+ {input: "C0", ptr: new(interface{}), value: []interface{}{}},
+ {input: "C50183040404", ptr: new(interface{}), value: []interface{}{[]byte{1}, []byte{4, 4, 4}}},
+ {
+ input: "C3010203",
+ ptr: new([]io.Reader),
+ error: "rlp: type io.Reader is not RLP-serializable",
+ },
+
+ // fuzzer crashes
+ {
+ input: "c330f9c030f93030ce3030303030303030bd303030303030",
+ ptr: new(interface{}),
+ error: "rlp: element is larger than containing list",
+ },
+}
+
+func uintp(i uint) *uint { return &i }
+
+// runTests runs every entry of decodeTests through the supplied decode
+// function and checks both the resulting error (by exact message) and the
+// decoded value (by deep equality). The decode parameter lets the same
+// table be reused with different reader implementations and stream setups.
+func runTests(t *testing.T, decode func([]byte, interface{}) error) {
+	t.Helper()
+
+	for i, test := range decodeTests {
+		input, err := hex.DecodeString(test.input)
+		if err != nil {
+			t.Errorf("test %d: invalid hex input %q", i, test.input)
+
+			continue
+		}
+
+		err = decode(input, test.ptr)
+		if err != nil && test.error == "" {
+			t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q",
+				i, err, test.ptr, test.input)
+
+			continue
+		}
+
+		if test.error != "" && fmt.Sprint(err) != test.error {
+			t.Errorf("test %d: Decode error mismatch\ngot %v\nwant %v\ndecoding into %T\ninput %q",
+				i, err, test.error, test.ptr, test.input)
+
+			continue
+		}
+
+		// Only compare decoded values when decoding succeeded.
+		deref := reflect.ValueOf(test.ptr).Elem().Interface()
+		if err == nil && !reflect.DeepEqual(deref, test.value) {
+			t.Errorf("test %d: value mismatch\ngot  %#v\nwant %#v\ndecoding into %T\ninput %q",
+				i, deref, test.value, test.ptr, test.input)
+		}
+	}
+}
+
+func TestDecodeWithByteReader(t *testing.T) {
+ runTests(t, func(input []byte, into interface{}) error {
+ return Decode(bytes.NewReader(input), into)
+ })
+}
+
+func testDecodeWithEncReader(t *testing.T, n int) {
+ t.Helper()
+
+ s := strings.Repeat("0", n)
+ _, r, _ := EncodeToReader(s)
+
+ var decoded string
+
+ err := Decode(r, &decoded)
+ if err != nil {
+ t.Errorf("Unexpected decode error with n=%v: %v", n, err)
+ }
+
+ if decoded != s {
+ t.Errorf("Decode mismatch with n=%v", n)
+ }
+}
+
+// This is a regression test checking that decoding from encReader
+// works for RLP values of size 8192 bytes or more.
+func TestDecodeWithEncReader(t *testing.T) {
+ testDecodeWithEncReader(t, 8188) // length with header is 8191
+ testDecodeWithEncReader(t, 8189) // length with header is 8192
+}
+
+// plainReader reads from a byte slice but does not
+// implement ReadByte. It is also not recognized by the
+// size validation. This is useful to test how the decoder
+// behaves on a non-buffered input stream.
+type plainReader []byte
+
+func newPlainReader(b []byte) io.Reader {
+ return (*plainReader)(&b)
+}
+
+func (r *plainReader) Read(buf []byte) (n int, err error) {
+ if len(*r) == 0 {
+ return 0, io.EOF
+ }
+
+ n = copy(buf, *r)
+ *r = (*r)[n:]
+
+ return n, nil
+}
+
+func TestDecodeWithNonByteReader(t *testing.T) {
+ runTests(t, func(input []byte, into interface{}) error {
+ return Decode(newPlainReader(input), into)
+ })
+}
+
+func TestDecodeStreamReset(t *testing.T) {
+ s := NewStream(nil, 0)
+
+ runTests(t, func(input []byte, into interface{}) error {
+ s.Reset(bytes.NewReader(input), 0)
+
+ return s.Decode(into)
+ })
+}
+
+type testDecoder struct{ called bool }
+
+func (t *testDecoder) DecodeRLP(s *Stream) error {
+ if _, err := s.Uint(); err != nil {
+ return err
+ }
+
+ t.called = true
+
+ return nil
+}
+
+func TestDecodeDecoder(t *testing.T) {
+ var s struct {
+ T1 testDecoder
+ T2 *testDecoder
+ T3 **testDecoder
+ }
+
+ if err := Decode(bytes.NewReader(unhex("C3010203")), &s); err != nil {
+ t.Fatalf("Decode error: %v", err)
+ }
+
+ if !s.T1.called {
+ t.Errorf("DecodeRLP was not called for (non-pointer) testDecoder")
+ }
+
+ if s.T2 == nil {
+ t.Errorf("*testDecoder has not been allocated")
+ } else if !s.T2.called {
+ t.Errorf("DecodeRLP was not called for *testDecoder")
+ }
+
+ if s.T3 == nil || *s.T3 == nil {
+ t.Errorf("**testDecoder has not been allocated")
+ } else if !(*s.T3).called {
+ t.Errorf("DecodeRLP was not called for **testDecoder")
+ }
+}
+
+func TestDecodeDecoderNilPointer(t *testing.T) {
+ var s struct {
+ T1 *testDecoder `rlp:"nil"`
+ T2 *testDecoder
+ }
+
+ if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil {
+ t.Fatalf("Decode error: %v", err)
+ }
+
+ if s.T1 != nil {
+ t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called)
+ }
+
+ if s.T2 == nil || !s.T2.called {
+ t.Errorf("decoder T2 not allocated/called")
+ }
+}
+
+type byteDecoder byte
+
+func (bd *byteDecoder) DecodeRLP(s *Stream) error {
+ _, err := s.Uint()
+ *bd = 255
+
+ return err
+}
+
+func (bd byteDecoder) called() bool {
+ return bd == 255
+}
+
+// This test verifies that the byte slice/byte array logic
+// does not kick in for element types implementing Decoder.
+func TestDecoderInByteSlice(t *testing.T) {
+ var slice []byteDecoder
+ if err := Decode(bytes.NewReader(unhex("C101")), &slice); err != nil {
+ t.Errorf("unexpected Decode error %v", err)
+ } else if !slice[0].called() {
+ t.Errorf("DecodeRLP not called for slice element")
+ }
+
+ var array [1]byteDecoder
+ if err := Decode(bytes.NewReader(unhex("C101")), &array); err != nil {
+ t.Errorf("unexpected Decode error %v", err)
+ } else if !array[0].called() {
+ t.Errorf("DecodeRLP not called for array element")
+ }
+}
+
+type unencodableDecoder func()
+
+func (f *unencodableDecoder) DecodeRLP(s *Stream) error {
+ if _, err := s.List(); err != nil {
+ return err
+ }
+
+ if err := s.ListEnd(); err != nil {
+ return err
+ }
+
+ *f = func() {}
+
+ return nil
+}
+
+func TestDecoderFunc(t *testing.T) {
+ var x func()
+ if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil {
+ t.Fatal(err)
+ }
+
+ x()
+}
+
+// This tests the validity checks for fields with struct tag "optional".
+func TestInvalidOptionalField(t *testing.T) {
+ type (
+ invalid1 struct {
+ A uint `rlp:"optional"`
+ B uint
+ }
+ invalid2 struct {
+ T []uint `rlp:"tail,optional"`
+ }
+ invalid3 struct {
+ T []uint `rlp:"optional,tail"`
+ }
+ )
+
+ tests := []struct {
+ v interface{}
+ err string
+ }{
+ {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`},
+ {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`},
+ {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`},
+ }
+ for _, test := range tests {
+ err := DecodeBytes(unhex("C20102"), test.v)
+ if err == nil {
+ t.Errorf("no error for %T", test.v)
+ } else if err.Error() != test.err {
+ t.Errorf("wrong error for %T: %v", test.v, err.Error())
+ }
+ }
+}
+
+func ExampleDecode() {
+ input, _ := hex.DecodeString("C90A1486666F6F626172")
+
+ type example struct {
+ A, B uint
+ String string
+ }
+
+ var s example
+
+ err := Decode(bytes.NewReader(input), &s)
+ if err != nil {
+ fmt.Printf("Error: %v\n", err)
+ } else {
+ // Output:
+ // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"}
+ fmt.Printf("Decoded value: %#v\n", s)
+ }
+}
+
+func ExampleDecode_structTagNil() {
+	// In this example, we'll use the "nil" struct tag to change
+	// how a pointer-typed field is decoded. The input contains an RLP
+	// list of one element, an empty string.
+	input := []byte{0xC1, 0x80}
+
+	// This type uses the normal rules.
+	// The empty input string is decoded as a pointer to an empty Go string.
+	var normalRules struct {
+		String *string
+	}
+
+	// The input is known-valid, but check the error anyway so a decode
+	// failure cannot cause a nil-pointer dereference below.
+	if err := Decode(bytes.NewReader(input), &normalRules); err != nil {
+		fmt.Printf("decode error: %v\n", err)
+
+		return
+	}
+	fmt.Printf("normal: String = %q\n", *normalRules.String)
+
+	// This type uses the struct tag.
+	// The empty input string is decoded as a nil pointer.
+	var withEmptyOK struct {
+		String *string `rlp:"nil"`
+	}
+
+	if err := Decode(bytes.NewReader(input), &withEmptyOK); err != nil {
+		fmt.Printf("decode error: %v\n", err)
+
+		return
+	}
+	fmt.Printf("with nil tag: String = %v\n", withEmptyOK.String)
+
+	// Output:
+	// normal: String = ""
+	// with nil tag: String =
+}
+
+func ExampleStream() {
+ input, _ := hex.DecodeString("C90A1486666F6F626172")
+ s := NewStream(bytes.NewReader(input), 0)
+
+ // Check what kind of value lies ahead
+ kind, size, _ := s.Kind()
+ fmt.Printf("Kind: %v size:%d\n", kind, size)
+
+ // Enter the list
+ if _, err := s.List(); err != nil {
+ fmt.Printf("List error: %v\n", err)
+
+ return
+ }
+
+ // Decode elements
+ fmt.Println(s.Uint())
+ fmt.Println(s.Uint())
+ fmt.Println(s.Bytes())
+
+ // Acknowledge end of list
+ if err := s.ListEnd(); err != nil {
+ fmt.Printf("ListEnd error: %v\n", err)
+ }
+ // Output:
+ // Kind: List size:9
+ // 10
+ // 20
+ // [102 111 111 98 97 114]
+}
+
+func BenchmarkDecodeUints(b *testing.B) {
+ enc := encodeTestSlice(90000)
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ var s []uint
+
+ r := bytes.NewReader(enc)
+
+ if err := Decode(r, &s); err != nil {
+ b.Fatalf("Decode error: %v", err)
+ }
+ }
+}
+
+func BenchmarkDecodeUintsReused(b *testing.B) {
+ enc := encodeTestSlice(100000)
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var s []uint
+
+ for i := 0; i < b.N; i++ {
+ r := bytes.NewReader(enc)
+ if err := Decode(r, &s); err != nil {
+ b.Fatalf("Decode error: %v", err)
+ }
+ }
+}
+
+func BenchmarkDecodeByteArrayStruct(b *testing.B) {
+ enc, err := EncodeToBytes(&byteArrayStruct{})
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var out byteArrayStruct
+
+ for i := 0; i < b.N; i++ {
+ if err := DecodeBytes(enc, &out); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkDecodeBigInts(b *testing.B) {
+ ints := make([]*big.Int, 200)
+ for i := range ints {
+ // 2 ^ i
+ ints[i] = new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(i)), nil)
+ }
+
+ enc, err := EncodeToBytes(ints)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.SetBytes(int64(len(enc)))
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ var out []*big.Int
+
+ for i := 0; i < b.N; i++ {
+ if err := DecodeBytes(enc, &out); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func encodeTestSlice(n uint) []byte {
+ s := make([]uint, n)
+ for i := uint(0); i < n; i++ {
+ s[i] = i
+ }
+
+ b, err := EncodeToBytes(s)
+ if err != nil {
+ panic(fmt.Sprintf("encode error: %v", err))
+ }
+
+ return b
+}
+
+// unhex decodes a hex string into bytes, ignoring embedded spaces.
+// It panics on invalid input, which is acceptable for a test helper
+// operating on hard-coded vectors.
+func unhex(str string) []byte {
+	b, err := hex.DecodeString(strings.ReplaceAll(str, " ", ""))
+	if err != nil {
+		panic(fmt.Sprintf("invalid hex string: %q", str))
+	}
+
+	return b
+}
diff --git a/helper/rlp/doc.go b/helper/rlp/doc.go
new file mode 100644
index 0000000000..eeeee9a43a
--- /dev/null
+++ b/helper/rlp/doc.go
@@ -0,0 +1,158 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+/*
+Package rlp implements the RLP serialization format.
+
+The purpose of RLP (Recursive Length Prefix) is to encode arbitrarily nested arrays of
+binary data, and RLP is the main encoding method used to serialize objects in Ethereum.
+The only purpose of RLP is to encode structure; encoding specific atomic data types (eg.
+strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be
+represented in big endian binary form with no leading zeroes (thus making the integer
+value zero equivalent to the empty string).
+
+RLP values are distinguished by a type tag. The type tag precedes the value in the input
+stream and defines the size and kind of the bytes that follow.
+
+# Encoding Rules
+
+Package rlp uses reflection and encodes RLP based on the Go type of the value.
+
+If the type implements the Encoder interface, Encode calls EncodeRLP. It does not
+call EncodeRLP on nil pointer values.
+
+To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct
+type, slice or array always encodes as an empty RLP list unless the slice or array has
+element type byte. A nil pointer to any other value encodes as the empty string.
+
+Struct values are encoded as an RLP list of all their encoded public fields. Recursive
+struct types are supported.
+
+To encode slices and arrays, the elements are encoded as an RLP list of the value's
+elements. Note that arrays and slices with element type uint8 or byte are always encoded
+as an RLP string.
+
+A Go string is encoded as an RLP string.
+
+An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP
+string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...)
+are not supported and will return an error when encoding.
+
+Boolean values are encoded as the unsigned integers zero (false) and one (true).
+
+An interface value encodes as the value contained in the interface.
+
+Floating point numbers, maps, channels and functions are not supported.
+
+# Decoding Rules
+
+Decoding uses the following type-dependent rules:
+
+If the type implements the Decoder interface, DecodeRLP is called.
+
+To decode into a pointer, the value will be decoded as the element type of the pointer. If
+the pointer is nil, a new value of the pointer's element type is allocated. If the pointer
+is non-nil, the existing value will be reused. Note that package rlp never leaves a
+pointer-type struct field as nil unless one of the "nil" struct tags is present.
+
+To decode into a struct, decoding expects the input to be an RLP list. The decoded
+elements of the list are assigned to each public field in the order given by the struct's
+definition. The input list must contain an element for each decoded field. Decoding
+returns an error if there are too few or too many elements for the struct.
+
+To decode into a slice, the input must be a list and the resulting slice will contain the
+input elements in order. For byte slices, the input must be an RLP string. Array types
+decode similarly, with the additional restriction that the number of input elements (or
+bytes) must match the array's defined length.
+
+To decode into a Go string, the input must be an RLP string. The input bytes are taken
+as-is and will not necessarily be valid UTF-8.
+
+To decode into an unsigned integer type, the input must also be an RLP string. The bytes
+are interpreted as a big endian representation of the integer. If the RLP string is larger
+than the bit size of the type, decoding will return an error. Decode also supports
+*big.Int. There is no size limit for big integers.
+
+To decode into a boolean, the input must contain an unsigned integer of value zero (false)
+or one (true).
+
+To decode into an interface value, one of these types is stored in the value:
+
+ []interface{}, for RLP lists
+ []byte, for RLP strings
+
+Non-empty interface types are not supported when decoding.
+Signed integers, floating point numbers, maps, channels and functions cannot be decoded into.
+
+# Struct Tags
+
+As with other encoding packages, the "-" tag ignores fields.
+
+ type StructWithIgnoredField struct{
+ Ignored uint `rlp:"-"`
+ Field uint
+ }
+
+Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping
+of fields to list elements. The "tail" tag, which may only be used on the last exported
+struct field, allows slurping up any excess list elements into a slice.
+
+ type StructWithTail struct{
+ Field uint
+ Tail []string `rlp:"tail"`
+ }
+
+The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is
+used on a struct field, all subsequent public fields must also be declared optional.
+
+When encoding a struct with optional fields, the output RLP list contains all values up to
+the last non-zero optional field.
+
+When decoding into a struct, optional fields may be omitted from the end of the input
+list. For the example below, this means input lists of one, two, or three elements are
+accepted.
+
+ type StructWithOptionalFields struct{
+ Required uint
+ Optional1 uint `rlp:"optional"`
+ Optional2 uint `rlp:"optional"`
+ }
+
+The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change
+the decoding rules for the field type. For regular pointer fields without the "nil" tag,
+input values must always match the required input length exactly and the decoder does not
+produce nil values. When the "nil" tag is set, input values of size zero decode as a nil
+pointer. This is especially useful for recursive types.
+
+ type StructWithNilField struct {
+ Field *[3]byte `rlp:"nil"`
+ }
+
+In the example above, Field allows two possible input sizes. For input 0xC180 (a list
+containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a
+list containing a 3-byte string), Field is set to a non-nil array pointer.
+
+RLP supports two kinds of empty values: empty lists and empty strings. When using the
+"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field
+whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice
+expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP
+list.
+
+The choice of null value can be made explicit with the "nilList" and "nilString" struct
+tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind
+defined by the tag.
+*/
+package rlp
diff --git a/helper/rlp/encbuffer.go b/helper/rlp/encbuffer.go
new file mode 100644
index 0000000000..976476a91d
--- /dev/null
+++ b/helper/rlp/encbuffer.go
@@ -0,0 +1,433 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "io"
+ "math/big"
+ "reflect"
+ "sync"
+)
+
+type encBuffer struct {
+ str []byte // string data, contains everything except list headers
+ lheads []listhead // all list headers
+ lhsize int // sum of sizes of all encoded list headers
+ sizebuf [9]byte // auxiliary buffer for uint encoding
+}
+
+// The global encBuffer pool.
+var encBufferPool = sync.Pool{
+ New: func() interface{} { return new(encBuffer) },
+}
+
+func getEncBuffer() *encBuffer {
+ buf, _ := encBufferPool.Get().(*encBuffer)
+ buf.reset()
+
+ return buf
+}
+
+func (buf *encBuffer) reset() {
+ buf.lhsize = 0
+ buf.str = buf.str[:0]
+ buf.lheads = buf.lheads[:0]
+}
+
+// size returns the length of the encoded data.
+func (buf *encBuffer) size() int {
+ return len(buf.str) + buf.lhsize
+}
+
+// makeBytes creates the encoder output.
+func (buf *encBuffer) makeBytes() []byte {
+ out := make([]byte, buf.size())
+ buf.copyTo(out)
+
+ return out
+}
+
+func (buf *encBuffer) copyTo(dst []byte) {
+ strpos := 0
+ pos := 0
+
+ for _, head := range buf.lheads {
+ // write string data before header
+ n := copy(dst[pos:], buf.str[strpos:head.offset])
+ pos += n
+ strpos += n
+ // write the header
+ enc := head.encode(dst[pos:])
+ pos += len(enc)
+ }
+ // copy string data after the last list header
+ copy(dst[pos:], buf.str[strpos:])
+}
+
+// writeTo writes the encoder output to w.
+func (buf *encBuffer) writeTo(w io.Writer) (err error) {
+ strpos := 0
+
+ for _, head := range buf.lheads {
+ // write string data before header
+ if head.offset-strpos > 0 {
+ n, err := w.Write(buf.str[strpos:head.offset])
+ strpos += n
+
+ if err != nil {
+ return err
+ }
+ }
+ // write the header
+ enc := head.encode(buf.sizebuf[:])
+ if _, err = w.Write(enc); err != nil {
+ return err
+ }
+ }
+
+ if strpos < len(buf.str) {
+ // write string data after the last list header
+ _, err = w.Write(buf.str[strpos:])
+ }
+
+ return err
+}
+
+// Write implements io.Writer and appends b directly to the output.
+func (buf *encBuffer) Write(b []byte) (int, error) {
+ buf.str = append(buf.str, b...)
+
+ return len(b), nil
+}
+
+// writeBool writes b as the integer 0 (false) or 1 (true).
+func (buf *encBuffer) writeBool(b bool) {
+ if b {
+ buf.str = append(buf.str, 0x01)
+ } else {
+ buf.str = append(buf.str, 0x80)
+ }
+}
+
+func (buf *encBuffer) writeUint64(i uint64) {
+ if i == 0 {
+ buf.str = append(buf.str, 0x80)
+ } else if i < 128 {
+ // fits single byte
+ buf.str = append(buf.str, byte(i))
+ } else {
+ s := putint(buf.sizebuf[1:], i)
+ buf.sizebuf[0] = 0x80 + byte(s)
+ buf.str = append(buf.str, buf.sizebuf[:s+1]...)
+ }
+}
+
+func (buf *encBuffer) writeBytes(b []byte) {
+ if len(b) == 1 && b[0] <= 0x7F {
+ // fits single byte, no string header
+ buf.str = append(buf.str, b[0])
+ } else {
+ buf.encodeStringHeader(len(b))
+ buf.str = append(buf.str, b...)
+ }
+}
+
+func (buf *encBuffer) writeString(s string) {
+ buf.writeBytes([]byte(s))
+}
+
+// wordBytes is the number of bytes in a big.Word
+const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8
+
+// writeBigInt writes i as an integer.
+func (buf *encBuffer) writeBigInt(i *big.Int) {
+ bitlen := i.BitLen()
+ if bitlen <= 64 {
+ buf.writeUint64(i.Uint64())
+
+ return
+ }
+
+ // Integer is larger than 64 bits, encode from i.Bits().
+ // The minimal byte length is bitlen rounded up to the next
+ // multiple of 8, divided by 8.
+ length := ((bitlen + 7) & -8) >> 3
+ buf.encodeStringHeader(length)
+ buf.str = append(buf.str, make([]byte, length)...)
+ index := length
+ sbuf := buf.str[len(buf.str)-length:]
+
+ for _, d := range i.Bits() {
+ for j := 0; j < wordBytes && index > 0; j++ {
+ index--
+
+ sbuf[index] = byte(d)
+
+ d >>= 8
+ }
+ }
+}
+
+// list adds a new list header to the header stack. It returns the index of the header.
+// Call listEnd with this index after encoding the content of the list.
+func (buf *encBuffer) list() int {
+ buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize})
+
+ return len(buf.lheads) - 1
+}
+
+func (buf *encBuffer) listEnd(index int) {
+ lh := &buf.lheads[index]
+
+ lh.size = buf.size() - lh.offset - lh.size
+ if lh.size < 56 {
+ buf.lhsize++ // length encoded into kind tag
+ } else {
+ buf.lhsize += 1 + intsize(uint64(lh.size))
+ }
+}
+
+func (buf *encBuffer) encode(val interface{}) error {
+ rval := reflect.ValueOf(val)
+
+ writer, err := cachedWriter(rval.Type())
+ if err != nil {
+ return err
+ }
+
+ return writer(rval, buf)
+}
+
+func (buf *encBuffer) encodeStringHeader(size int) {
+ if size < 56 {
+ buf.str = append(buf.str, 0x80+byte(size))
+ } else {
+ sizesize := putint(buf.sizebuf[1:], uint64(size))
+ buf.sizebuf[0] = 0xB7 + byte(sizesize)
+ buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...)
+ }
+}
+
+// encReader is the io.Reader returned by EncodeToReader.
+// It releases its encbuf at EOF.
+type encReader struct {
+ buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF.
+ lhpos int // index of list header that we're reading
+ strpos int // current position in string buffer
+ piece []byte // next piece to be read
+}
+
+func (r *encReader) Read(b []byte) (n int, err error) {
+ for {
+ if r.piece = r.next(); r.piece == nil {
+ // Put the encode buffer back into the pool at EOF when it
+ // is first encountered. Subsequent calls still return EOF
+ // as the error but the buffer is no longer valid.
+ if r.buf != nil {
+ encBufferPool.Put(r.buf)
+ r.buf = nil
+ }
+
+ return n, io.EOF
+ }
+
+ nn := copy(b[n:], r.piece)
+ n += nn
+
+ if nn < len(r.piece) {
+ // piece didn't fit, see you next time.
+ r.piece = r.piece[nn:]
+
+ return n, nil
+ }
+
+ r.piece = nil
+ }
+}
+
+// next returns the next piece of data to be read.
+// it returns nil at EOF.
+func (r *encReader) next() []byte {
+ switch {
+ case r.buf == nil:
+ return nil
+
+ case r.piece != nil:
+ // There is still data available for reading.
+ return r.piece
+
+ case r.lhpos < len(r.buf.lheads):
+ // We're before the last list header.
+ head := r.buf.lheads[r.lhpos]
+
+ sizebefore := head.offset - r.strpos
+ if sizebefore > 0 {
+ // String data before header.
+ p := r.buf.str[r.strpos:head.offset]
+ r.strpos += sizebefore
+
+ return p
+ }
+
+ r.lhpos++
+
+ return head.encode(r.buf.sizebuf[:])
+
+ case r.strpos < len(r.buf.str):
+ // String data at the end, after all list headers.
+ p := r.buf.str[r.strpos:]
+ r.strpos = len(r.buf.str)
+
+ return p
+
+ default:
+ return nil
+ }
+}
+
+func encBufferFromWriter(w io.Writer) *encBuffer {
+ switch w := w.(type) {
+ case EncoderBuffer:
+ return w.buf
+ case *EncoderBuffer:
+ return w.buf
+ case *encBuffer:
+ return w
+ default:
+ return nil
+ }
+}
+
+// EncoderBuffer is a buffer for incremental encoding.
+//
+// The zero value is NOT ready for use. To get a usable buffer,
+// create it using NewEncoderBuffer or call Reset.
+type EncoderBuffer struct {
+ buf *encBuffer
+ dst io.Writer
+
+ ownBuffer bool
+}
+
+// NewEncoderBuffer creates an encoder buffer.
+func NewEncoderBuffer(dst io.Writer) EncoderBuffer {
+ var w EncoderBuffer
+
+ w.Reset(dst)
+
+ return w
+}
+
+// Reset truncates the buffer and sets the output destination.
+func (w *EncoderBuffer) Reset(dst io.Writer) {
+ if w.buf != nil && !w.ownBuffer {
+ panic("can't Reset derived EncoderBuffer")
+ }
+
+ // If the destination writer has an *encBuffer, use it.
+ // Note that w.ownBuffer is left false here.
+ if dst != nil {
+ if outer := encBufferFromWriter(dst); outer != nil {
+ *w = EncoderBuffer{outer, nil, false}
+
+ return
+ }
+ }
+
+ // Get a fresh buffer.
+ if w.buf == nil {
+ w.buf, _ = encBufferPool.Get().(*encBuffer)
+ w.ownBuffer = true
+ }
+
+ w.buf.reset()
+ w.dst = dst
+}
+
+// Flush writes encoded RLP data to the output writer. This can only be called once.
+// If you want to re-use the buffer after Flush, you must call Reset.
+func (w *EncoderBuffer) Flush() error {
+ var err error
+
+ if w.dst != nil {
+ err = w.buf.writeTo(w.dst)
+ }
+ // Release the internal buffer.
+ if w.ownBuffer {
+ encBufferPool.Put(w.buf)
+ }
+
+ *w = EncoderBuffer{}
+
+ return err
+}
+
+// ToBytes returns the encoded bytes.
+func (w *EncoderBuffer) ToBytes() []byte {
+ return w.buf.makeBytes()
+}
+
+// AppendToBytes appends the encoded bytes to dst.
+func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte {
+ size := w.buf.size()
+ out := append(dst, make([]byte, size)...)
+ w.buf.copyTo(out[len(dst):])
+
+ return out
+}
+
+// Write appends b directly to the encoder output.
+func (w EncoderBuffer) Write(b []byte) (int, error) {
+ return w.buf.Write(b)
+}
+
+// WriteBool writes b as the integer 0 (false) or 1 (true).
+func (w EncoderBuffer) WriteBool(b bool) {
+ w.buf.writeBool(b)
+}
+
+// WriteUint64 encodes an unsigned integer.
+func (w EncoderBuffer) WriteUint64(i uint64) {
+ w.buf.writeUint64(i)
+}
+
+// WriteBigInt encodes a big.Int as an RLP string.
+// Note: Unlike with Encode, the sign of i is ignored.
+func (w EncoderBuffer) WriteBigInt(i *big.Int) {
+ w.buf.writeBigInt(i)
+}
+
+// WriteBytes encodes b as an RLP string.
+func (w EncoderBuffer) WriteBytes(b []byte) {
+ w.buf.writeBytes(b)
+}
+
+// WriteString encodes s as an RLP string.
+func (w EncoderBuffer) WriteString(s string) {
+ w.buf.writeString(s)
+}
+
+// List starts a list. It returns an internal index. Call EndList with
+// this index after encoding the content to finish the list.
+func (w EncoderBuffer) List() int {
+ return w.buf.list()
+}
+
+// ListEnd finishes the given list.
+func (w EncoderBuffer) ListEnd(index int) {
+ w.buf.listEnd(index)
+}
diff --git a/helper/rlp/encbuffer_example_test.go b/helper/rlp/encbuffer_example_test.go
new file mode 100644
index 0000000000..1c3fc7cdb5
--- /dev/null
+++ b/helper/rlp/encbuffer_example_test.go
@@ -0,0 +1,46 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp_test
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+)
+
+func ExampleEncoderBuffer() {
+ var w bytes.Buffer
+
+ // Encode [4, [5, 6]] to w.
+ buf := rlp.NewEncoderBuffer(&w)
+ l1 := buf.List()
+ buf.WriteUint64(4)
+ l2 := buf.List()
+ buf.WriteUint64(5)
+ buf.WriteUint64(6)
+ buf.ListEnd(l2)
+ buf.ListEnd(l1)
+
+ if err := buf.Flush(); err != nil {
+ panic(err)
+ }
+
+ // Output:
+ // C404C20506
+ fmt.Printf("%X\n", w.Bytes())
+}
diff --git a/helper/rlp/encode.go b/helper/rlp/encode.go
new file mode 100644
index 0000000000..4e326fd973
--- /dev/null
+++ b/helper/rlp/encode.go
@@ -0,0 +1,536 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "reflect"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp/internal/rlpstruct"
+)
+
+var (
+ // Common encoded values.
+ // These are useful when implementing EncodeRLP.
+
+ // EmptyString is the encoding of an empty string.
+ EmptyString = []byte{0x80}
+ // EmptyList is the encoding of an empty list.
+ EmptyList = []byte{0xC0}
+)
+
+var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int")
+
+// Encoder is implemented by types that require custom
+// encoding rules or want to encode private fields.
+type Encoder interface {
+ // EncodeRLP should write the RLP encoding of its receiver to w.
+ // If the implementation is a pointer method, it may also be
+ // called for nil pointers.
+ //
+ // Implementations should generate valid RLP. The data written is
+ // not verified at the moment, but a future version might. It is
+ // recommended to write only a single value but writing multiple
+ // values or no value at all is also permitted.
+ EncodeRLP(io.Writer) error
+}
+
+// Encode writes the RLP encoding of val to w. Note that Encode may
+// perform many small writes in some cases. Consider making w
+// buffered.
+//
+// Please see package-level documentation of encoding rules.
+func Encode(w io.Writer, val interface{}) error {
+ // Optimization: reuse *encBuffer when called by EncodeRLP.
+ if buf := encBufferFromWriter(w); buf != nil {
+ return buf.encode(val)
+ }
+
+ buf := getEncBuffer()
+ defer encBufferPool.Put(buf)
+
+ if err := buf.encode(val); err != nil {
+ return err
+ }
+
+ return buf.writeTo(w)
+}
+
+// EncodeToBytes returns the RLP encoding of val.
+// Please see package-level documentation for the encoding rules.
+func EncodeToBytes(val interface{}) ([]byte, error) {
+ buf := getEncBuffer()
+ defer encBufferPool.Put(buf)
+
+ if err := buf.encode(val); err != nil {
+ return nil, err
+ }
+
+ return buf.makeBytes(), nil
+}
+
+// EncodeToReader returns a reader from which the RLP encoding of val
+// can be read. The returned size is the total size of the encoded
+// data.
+//
+// Please see the documentation of Encode for the encoding rules.
+func EncodeToReader(val interface{}) (size int, r io.Reader, err error) {
+ buf := getEncBuffer()
+ if err := buf.encode(val); err != nil {
+ encBufferPool.Put(buf)
+
+ return 0, nil, err
+ }
+ // Note: can't put the buffer back into the pool here
+ // because it is held by encReader. The reader puts it
+ // back when it has been fully consumed.
+ return buf.size(), &encReader{buf: buf}, nil
+}
+
+type listhead struct {
+ offset int // index of this header in string data
+ size int // total size of encoded data (including list headers)
+}
+
+// encode writes head to the given buffer, which must be at least
+// 9 bytes long. It returns the encoded bytes.
+func (head *listhead) encode(buf []byte) []byte {
+ return buf[:puthead(buf, 0xC0, 0xF7, uint64(head.size))]
+}
+
+// headsize returns the size of a list or string header
+// for a value of the given size.
+func headsize(size uint64) int {
+ if size < 56 {
+ return 1
+ }
+
+ return 1 + intsize(size)
+}
+
+// puthead writes a list or string header to buf.
+// buf must be at least 9 bytes long.
+func puthead(buf []byte, smalltag, largetag byte, size uint64) int {
+ if size < 56 {
+ buf[0] = smalltag + byte(size)
+
+ return 1
+ }
+
+ sizesize := putint(buf[1:], size)
+ buf[0] = largetag + byte(sizesize)
+
+ return sizesize + 1
+}
+
+var encoderInterface = reflect.TypeOf(new(Encoder)).Elem()
+
+// makeWriter creates a writer function for the given type.
+func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
+ kind := typ.Kind()
+
+ switch {
+ case typ == rawValueType:
+ return writeRawValue, nil
+ case typ.AssignableTo(reflect.PtrTo(bigInt)):
+ return writeBigIntPtr, nil
+ case typ.AssignableTo(bigInt):
+ return writeBigIntNoPtr, nil
+ case kind == reflect.Ptr:
+ return makePtrWriter(typ, ts)
+ case reflect.PtrTo(typ).Implements(encoderInterface):
+ return makeEncoderWriter(typ), nil
+ case isUint(kind):
+ return writeUint, nil
+ case kind == reflect.Bool:
+ return writeBool, nil
+ case kind == reflect.String:
+ return writeString, nil
+ case kind == reflect.Slice && isByte(typ.Elem()):
+ return writeBytes, nil
+ case kind == reflect.Array && isByte(typ.Elem()):
+ return makeByteArrayWriter(typ), nil
+ case kind == reflect.Slice || kind == reflect.Array:
+ return makeSliceWriter(typ, ts)
+ case kind == reflect.Struct:
+ return makeStructWriter(typ)
+ case kind == reflect.Interface:
+ return writeInterface, nil
+ default:
+ return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ)
+ }
+}
+
+func writeRawValue(val reflect.Value, w *encBuffer) error {
+ w.str = append(w.str, val.Bytes()...)
+
+ return nil
+}
+
+func writeUint(val reflect.Value, w *encBuffer) error {
+ w.writeUint64(val.Uint())
+
+ return nil
+}
+
+func writeBool(val reflect.Value, w *encBuffer) error {
+ w.writeBool(val.Bool())
+
+ return nil
+}
+
+func writeBigIntPtr(val reflect.Value, w *encBuffer) error {
+ ptr, _ := val.Interface().(*big.Int)
+ if ptr == nil {
+ w.str = append(w.str, 0x80)
+
+ return nil
+ }
+
+ if ptr.Sign() == -1 {
+ return ErrNegativeBigInt
+ }
+
+ w.writeBigInt(ptr)
+
+ return nil
+}
+
+func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error {
+ i, _ := val.Interface().(big.Int)
+ if i.Sign() == -1 {
+ return ErrNegativeBigInt
+ }
+
+ w.writeBigInt(&i)
+
+ return nil
+}
+
+func writeBytes(val reflect.Value, w *encBuffer) error {
+ w.writeBytes(val.Bytes())
+
+ return nil
+}
+
+func makeByteArrayWriter(typ reflect.Type) writer {
+ switch typ.Len() {
+ case 0:
+ return writeLengthZeroByteArray
+ case 1:
+ return writeLengthOneByteArray
+ default:
+ length := typ.Len()
+
+ return func(val reflect.Value, w *encBuffer) error {
+ if !val.CanAddr() {
+ // Getting the byte slice of val requires it to be addressable. Make it
+ // addressable by copying.
+ cp := reflect.New(val.Type()).Elem()
+ cp.Set(val)
+ val = cp
+ }
+
+ slice := byteArrayBytes(val, length)
+ w.encodeStringHeader(len(slice))
+ w.str = append(w.str, slice...)
+
+ return nil
+ }
+ }
+}
+
+func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error {
+ w.str = append(w.str, 0x80)
+
+ return nil
+}
+
+func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error {
+ b := byte(val.Index(0).Uint())
+ if b <= 0x7f {
+ w.str = append(w.str, b)
+ } else {
+ w.str = append(w.str, 0x81, b)
+ }
+
+ return nil
+}
+
+func writeString(val reflect.Value, w *encBuffer) error {
+ s := val.String()
+ if len(s) == 1 && s[0] <= 0x7f {
+ // fits single byte, no string header
+ w.str = append(w.str, s[0])
+ } else {
+ w.encodeStringHeader(len(s))
+ w.str = append(w.str, s...)
+ }
+
+ return nil
+}
+
+func writeInterface(val reflect.Value, w *encBuffer) error {
+ if val.IsNil() {
+ // Write empty list. This is consistent with the previous RLP
+ // encoder that we had and should therefore avoid any
+ // problems.
+ w.str = append(w.str, 0xC0)
+
+ return nil
+ }
+
+ eval := val.Elem()
+
+ writer, err := cachedWriter(eval.Type())
+ if err != nil {
+ return err
+ }
+
+ return writer(eval, w)
+}
+
+func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
+ etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{})
+ if etypeinfo.writerErr != nil {
+ return nil, etypeinfo.writerErr
+ }
+
+ var wfn writer
+
+ if ts.Tail {
+ // This is for struct tail slices.
+ // w.list is not called for them.
+ wfn = func(val reflect.Value, w *encBuffer) error {
+ vlen := val.Len()
+ for i := 0; i < vlen; i++ {
+ if err := etypeinfo.writer(val.Index(i), w); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ }
+ } else {
+ // This is for regular slices and arrays.
+ wfn = func(val reflect.Value, w *encBuffer) error {
+ vlen := val.Len()
+ if vlen == 0 {
+ w.str = append(w.str, 0xC0)
+
+ return nil
+ }
+
+ listOffset := w.list()
+
+ for i := 0; i < vlen; i++ {
+ if err := etypeinfo.writer(val.Index(i), w); err != nil {
+ return err
+ }
+ }
+
+ w.listEnd(listOffset)
+
+ return nil
+ }
+ }
+
+ return wfn, nil
+}
+
+func makeStructWriter(typ reflect.Type) (writer, error) {
+ fields, err := structFields(typ)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, f := range fields {
+ if f.info.writerErr != nil {
+ return nil, structFieldError{typ, f.index, f.info.writerErr}
+ }
+ }
+
+ var writer writer
+
+ firstOptionalField := firstOptionalField(fields)
+ if firstOptionalField == len(fields) {
+ // This is the writer function for structs without any optional fields.
+ writer = func(val reflect.Value, w *encBuffer) error {
+ lh := w.list()
+
+ for _, f := range fields {
+ if err := f.info.writer(val.Field(f.index), w); err != nil {
+ return err
+ }
+ }
+
+ w.listEnd(lh)
+
+ return nil
+ }
+ } else {
+ // If there are any "optional" fields, the writer needs to perform additional
+ // checks to determine the output list length.
+ writer = func(val reflect.Value, w *encBuffer) error {
+ lastField := len(fields) - 1
+ for ; lastField >= firstOptionalField; lastField-- {
+ if !val.Field(fields[lastField].index).IsZero() {
+ break
+ }
+ }
+
+ lh := w.list()
+ for i := 0; i <= lastField; i++ {
+ if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil {
+ return err
+ }
+ }
+
+ w.listEnd(lh)
+
+ return nil
+ }
+ }
+
+ return writer, nil
+}
+
+func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) {
+ nilEncoding := byte(0xC0)
+ if typeNilKind(typ.Elem(), ts) == String {
+ nilEncoding = 0x80
+ }
+
+ etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{})
+ if etypeinfo.writerErr != nil {
+ return nil, etypeinfo.writerErr
+ }
+
+ writer := func(val reflect.Value, w *encBuffer) error {
+ if ev := val.Elem(); ev.IsValid() {
+ return etypeinfo.writer(ev, w)
+ }
+
+ w.str = append(w.str, nilEncoding)
+
+ return nil
+ }
+
+ return writer, nil
+}
+
+func makeEncoderWriter(typ reflect.Type) writer {
+ if typ.Implements(encoderInterface) {
+ return func(val reflect.Value, w *encBuffer) error {
+ //nolint:forcetypeassert
+ return val.Interface().(Encoder).EncodeRLP(w)
+ }
+ }
+
+ w := func(val reflect.Value, w *encBuffer) error {
+ if !val.CanAddr() {
+ // package json simply doesn't call MarshalJSON for this case, but encodes the
+ // value as if it didn't implement the interface. We don't want to handle it that
+ // way.
+ return fmt.Errorf("rlp: unadressable value of type %v, EncodeRLP is pointer method", val.Type())
+ }
+
+ //nolint:forcetypeassert
+ return val.Addr().Interface().(Encoder).EncodeRLP(w)
+ }
+
+ return w
+}
+
+// putint writes i to the beginning of b in big endian byte
+// order, using the least number of bytes needed to represent i.
+func putint(b []byte, i uint64) (size int) {
+ switch {
+ case i < (1 << 8):
+ b[0] = byte(i)
+
+ return 1
+ case i < (1 << 16):
+ b[0] = byte(i >> 8)
+ b[1] = byte(i)
+
+ return 2
+ case i < (1 << 24):
+ b[0] = byte(i >> 16)
+ b[1] = byte(i >> 8)
+ b[2] = byte(i)
+
+ return 3
+ case i < (1 << 32):
+ b[0] = byte(i >> 24)
+ b[1] = byte(i >> 16)
+ b[2] = byte(i >> 8)
+ b[3] = byte(i)
+
+ return 4
+ case i < (1 << 40):
+ b[0] = byte(i >> 32)
+ b[1] = byte(i >> 24)
+ b[2] = byte(i >> 16)
+ b[3] = byte(i >> 8)
+ b[4] = byte(i)
+
+ return 5
+ case i < (1 << 48):
+ b[0] = byte(i >> 40)
+ b[1] = byte(i >> 32)
+ b[2] = byte(i >> 24)
+ b[3] = byte(i >> 16)
+ b[4] = byte(i >> 8)
+ b[5] = byte(i)
+
+ return 6
+ case i < (1 << 56):
+ b[0] = byte(i >> 48)
+ b[1] = byte(i >> 40)
+ b[2] = byte(i >> 32)
+ b[3] = byte(i >> 24)
+ b[4] = byte(i >> 16)
+ b[5] = byte(i >> 8)
+ b[6] = byte(i)
+
+ return 7
+ default:
+ b[0] = byte(i >> 56)
+ b[1] = byte(i >> 48)
+ b[2] = byte(i >> 40)
+ b[3] = byte(i >> 32)
+ b[4] = byte(i >> 24)
+ b[5] = byte(i >> 16)
+ b[6] = byte(i >> 8)
+ b[7] = byte(i)
+
+ return 8
+ }
+}
+
+// intsize computes the minimum number of bytes required to store i.
+func intsize(i uint64) (size int) {
+ for size = 1; ; size++ {
+ if i >>= 8; i == 0 {
+ return size
+ }
+ }
+}
diff --git a/helper/rlp/encode_test.go b/helper/rlp/encode_test.go
new file mode 100644
index 0000000000..0e9ab1e5c3
--- /dev/null
+++ b/helper/rlp/encode_test.go
@@ -0,0 +1,624 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+// testEncoder implements Encoder with a pointer receiver and can be
+// configured to return a fixed error, for exercising error propagation.
+type testEncoder struct {
+ err error
+}
+
+// EncodeRLP writes a fixed 10-byte payload, or returns e.err if set.
+// It panics on a nil receiver: encoding a nil *testEncoder must be
+// handled by the rlp package without invoking this method.
+func (e *testEncoder) EncodeRLP(w io.Writer) error {
+ if e == nil {
+ panic("EncodeRLP called on nil value")
+ }
+
+ if e.err != nil {
+ return e.err
+ }
+
+ w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1})
+
+ return nil
+}
+
+// testEncoderValueMethod implements Encoder with a value receiver, to
+// test that value-method encoders are called for non-addressable values.
+type testEncoderValueMethod struct{}
+
+// EncodeRLP writes a fixed 3-byte payload.
+func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error {
+ w.Write([]byte{0xFA, 0xFE, 0xF0})
+
+ return nil
+}
+
+// byteEncoder is a named byte type implementing Encoder, used to verify
+// that the Encoder interface takes precedence over raw-byte handling.
+type byteEncoder byte
+
+// EncodeRLP writes an empty RLP list regardless of the byte value.
+func (e byteEncoder) EncodeRLP(w io.Writer) error {
+ w.Write(EmptyList)
+
+ return nil
+}
+
+// undecodableEncoder is a func type implementing Encoder. It verifies
+// that Encoder works even for types that are otherwise unsupported
+// (func values cannot normally be RLP-encoded).
+type undecodableEncoder func()
+
+// EncodeRLP writes a fixed 3-byte payload.
+func (f undecodableEncoder) EncodeRLP(w io.Writer) error {
+ w.Write([]byte{0xF5, 0xF5, 0xF5})
+
+ return nil
+}
+
+// encodableReader is a struct that also satisfies io.Reader. When it is
+// encoded behind an io.Reader interface, the struct fields must be
+// encoded and Read must never be invoked.
+type encodableReader struct {
+ A, B uint
+}
+
+// Read panics unconditionally: the encoder must not call it.
+func (e *encodableReader) Read(b []byte) (int, error) {
+ panic("called")
+}
+
+// namedByteType checks that named byte types encode like plain bytes.
+type namedByteType byte
+
+var (
+ // Compile-time interface satisfaction checks.
+ _ = Encoder(&testEncoder{})
+ _ = Encoder(byteEncoder(0))
+
+ // reader wraps a struct behind a non-Encoder interface; encoding it
+ // must encode the underlying struct fields.
+ reader io.Reader = &encodableReader{1, 2}
+)
+
+// encTest describes one encoder test case: val is encoded and compared
+// against output (expected hex, upper-case) or error (expected error
+// message). Exactly one of output/error is meaningful per case.
+type encTest struct {
+ val interface{}
+ output, error string
+}
+
+var encTests = []encTest{
+ // booleans
+ {val: true, output: "01"},
+ {val: false, output: "80"},
+
+ // integers
+ {val: uint32(0), output: "80"},
+ {val: uint32(127), output: "7F"},
+ {val: uint32(128), output: "8180"},
+ {val: uint32(256), output: "820100"},
+ {val: uint32(1024), output: "820400"},
+ {val: uint32(0xFFFFFF), output: "83FFFFFF"},
+ {val: uint32(0xFFFFFFFF), output: "84FFFFFFFF"},
+ {val: uint64(0xFFFFFFFF), output: "84FFFFFFFF"},
+ {val: uint64(0xFFFFFFFFFF), output: "85FFFFFFFFFF"},
+ {val: uint64(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"},
+ {val: uint64(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"},
+ {val: uint64(0xFFFFFFFFFFFFFFFF), output: "88FFFFFFFFFFFFFFFF"},
+
+ // big integers (should match uint for small values)
+ {val: big.NewInt(0), output: "80"},
+ {val: big.NewInt(1), output: "01"},
+ {val: big.NewInt(127), output: "7F"},
+ {val: big.NewInt(128), output: "8180"},
+ {val: big.NewInt(256), output: "820100"},
+ {val: big.NewInt(1024), output: "820400"},
+ {val: big.NewInt(0xFFFFFF), output: "83FFFFFF"},
+ {val: big.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"},
+ {val: big.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"},
+ {val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"},
+ {val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"},
+ {
+ val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")),
+ output: "8F102030405060708090A0B0C0D0E0F2",
+ },
+ {
+ val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")),
+ output: "9C0100020003000400050006000700080009000A000B000C000D000E01",
+ },
+ {
+ val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")),
+ output: "A1010000000000000000000000000000000000000000000000000000000000000000",
+ },
+ {
+ val: veryBigInt,
+ output: "89FFFFFFFFFFFFFFFFFF",
+ },
+ {
+ val: veryVeryBigInt,
+ output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001",
+ },
+
+ // non-pointer big.Int
+ {val: *big.NewInt(0), output: "80"},
+ {val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"},
+
+ // negative ints are not supported
+ {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
+ {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"},
+
+ // byte arrays
+ {val: [0]byte{}, output: "80"},
+ {val: [1]byte{0}, output: "00"},
+ {val: [1]byte{1}, output: "01"},
+ {val: [1]byte{0x7F}, output: "7F"},
+ {val: [1]byte{0x80}, output: "8180"},
+ {val: [1]byte{0xFF}, output: "81FF"},
+ {val: [3]byte{1, 2, 3}, output: "83010203"},
+ {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+
+ // named byte type arrays
+ {val: [0]namedByteType{}, output: "80"},
+ {val: [1]namedByteType{0}, output: "00"},
+ {val: [1]namedByteType{1}, output: "01"},
+ {val: [1]namedByteType{0x7F}, output: "7F"},
+ {val: [1]namedByteType{0x80}, output: "8180"},
+ {val: [1]namedByteType{0xFF}, output: "81FF"},
+ {val: [3]namedByteType{1, 2, 3}, output: "83010203"},
+ {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"},
+
+ // byte slices
+ {val: []byte{}, output: "80"},
+ {val: []byte{0}, output: "00"},
+ {val: []byte{0x7E}, output: "7E"},
+ {val: []byte{0x7F}, output: "7F"},
+ {val: []byte{0x80}, output: "8180"},
+ {val: []byte{1, 2, 3}, output: "83010203"},
+
+ // named byte type slices
+ {val: []namedByteType{}, output: "80"},
+ {val: []namedByteType{0}, output: "00"},
+ {val: []namedByteType{0x7E}, output: "7E"},
+ {val: []namedByteType{0x7F}, output: "7F"},
+ {val: []namedByteType{0x80}, output: "8180"},
+ {val: []namedByteType{1, 2, 3}, output: "83010203"},
+
+ // strings
+ {val: "", output: "80"},
+ {val: "\x7E", output: "7E"},
+ {val: "\x7F", output: "7F"},
+ {val: "\x80", output: "8180"},
+ {val: "dog", output: "83646F67"},
+ {
+ val: "Lorem ipsum dolor sit amet, consectetur adipisicing eli",
+ output: "B74C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C69",
+ },
+ {
+ val: "Lorem ipsum dolor sit amet, consectetur adipisicing elit",
+ output: "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974",
+ },
+ {
+ val: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur mauris magna, suscipit sed vehicula non, iaculis faucibus tortor. Proin suscipit ultricies malesuada. Duis tortor elit, dictum quis tristique eu, ultrices at risus. Morbi a est imperdiet mi ullamcorper aliquet suscipit nec lorem. Aenean quis leo mollis, vulputate elit varius, consequat enim. Nulla ultrices turpis justo, et posuere urna consectetur nec. Proin non convallis metus. Donec tempor ipsum in mauris congue sollicitudin. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Suspendisse convallis sem vel massa faucibus, eget lacinia lacus tempor. Nulla quis ultricies purus. Proin auctor rhoncus nibh condimentum mollis. Aliquam consequat enim at metus luctus, a eleifend purus egestas. Curabitur at nibh metus. Nam bibendum, neque at auctor tristique, lorem libero aliquet arcu, non interdum tellus lectus sit amet eros. Cras rhoncus, metus ac ornare cursus, dolor justo ultrices metus, at ullamcorper volutpat",
+ output: "B904004C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E73656374657475722061646970697363696E6720656C69742E20437572616269747572206D6175726973206D61676E612C20737573636970697420736564207665686963756C61206E6F6E2C20696163756C697320666175636962757320746F72746F722E2050726F696E20737573636970697420756C74726963696573206D616C6573756164612E204475697320746F72746F7220656C69742C2064696374756D2071756973207472697374697175652065752C20756C7472696365732061742072697375732E204D6F72626920612065737420696D70657264696574206D6920756C6C616D636F7270657220616C6971756574207375736369706974206E6563206C6F72656D2E2041656E65616E2071756973206C656F206D6F6C6C69732C2076756C70757461746520656C6974207661726975732C20636F6E73657175617420656E696D2E204E756C6C6120756C74726963657320747572706973206A7573746F2C20657420706F73756572652075726E6120636F6E7365637465747572206E65632E2050726F696E206E6F6E20636F6E76616C6C6973206D657475732E20446F6E65632074656D706F7220697073756D20696E206D617572697320636F6E67756520736F6C6C696369747564696E2E20566573746962756C756D20616E746520697073756D207072696D697320696E206661756369627573206F726369206C756374757320657420756C74726963657320706F737565726520637562696C69612043757261653B2053757370656E646973736520636F6E76616C6C69732073656D2076656C206D617373612066617563696275732C2065676574206C6163696E6961206C616375732074656D706F722E204E756C6C61207175697320756C747269636965732070757275732E2050726F696E20617563746F722072686F6E637573206E69626820636F6E64696D656E74756D206D6F6C6C69732E20416C697175616D20636F6E73657175617420656E696D206174206D65747573206C75637475732C206120656C656966656E6420707572757320656765737461732E20437572616269747572206174206E696268206D657475732E204E616D20626962656E64756D2C206E6571756520617420617563746F72207472697374697175652C206C6F72656D206C696265726F20616C697175657420617263752C206E6F6E20696E74657264756D2074656C6C7573206C65637475732073697420616D65742065726F732E20437261732072686F6E6375732C206D65747573206163206F726E617265206375727375732C20646F6C6F72206A7573746F20756C747269636
573206D657475732C20617420756C6C616D636F7270657220766F6C7574706174",
+ },
+
+ // slices
+ {val: []uint{}, output: "C0"},
+ {val: []uint{1, 2, 3}, output: "C3010203"},
+ {
+ // [ [], [[]], [ [], [[]] ] ]
+ val: []interface{}{[]interface{}{}, [][]interface{}{{}}, []interface{}{[]interface{}{}, [][]interface{}{{}}}},
+ output: "C7C0C1C0C3C0C1C0",
+ },
+ {
+ val: []string{"aaa", "bbb", "ccc", "ddd", "eee", "fff", "ggg", "hhh", "iii", "jjj", "kkk", "lll", "mmm", "nnn", "ooo"},
+ output: "F83C836161618362626283636363836464648365656583666666836767678368686883696969836A6A6A836B6B6B836C6C6C836D6D6D836E6E6E836F6F6F",
+ },
+ {
+ val: []interface{}{uint(1), uint(0xFFFFFF), []interface{}{[]uint{4, 5, 5}}, "abc"},
+ output: "CE0183FFFFFFC4C304050583616263",
+ },
+ {
+ val: [][]string{
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ {"asdf", "qwer", "zxcv"},
+ },
+ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376",
+ },
+
+ // RawValue
+ {val: RawValue(unhex("01")), output: "01"},
+ {val: RawValue(unhex("82FFFF")), output: "82FFFF"},
+ {val: []RawValue{unhex("01"), unhex("02")}, output: "C20102"},
+
+ // structs
+ {val: simplestruct{}, output: "C28080"},
+ {val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"},
+ {val: &recstruct{5, nil}, output: "C205C0"},
+ {val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"},
+ {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"},
+
+ // struct tag "-"
+ {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"},
+
+ // struct tag "tail"
+ {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"},
+ {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"},
+ {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"},
+ {val: &tailRaw{A: 1, Tail: nil}, output: "C101"},
+
+ // struct tag "optional"
+ {val: &optionalFields{}, output: "C180"},
+ {val: &optionalFields{A: 1}, output: "C101"},
+ {val: &optionalFields{A: 1, B: 2}, output: "C20102"},
+ {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"},
+ {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"},
+ {val: &optionalAndTailField{A: 1}, output: "C101"},
+ {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"},
+ {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
+ {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"},
+ {val: &optionalBigIntField{A: 1}, output: "C101"},
+ {val: &optionalPtrField{A: 1}, output: "C101"},
+ {val: &optionalPtrFieldNil{A: 1}, output: "C101"},
+ {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"},
+ {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"},
+ {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail
+ {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail
+
+ // nil
+ {val: (*uint)(nil), output: "80"},
+ {val: (*string)(nil), output: "80"},
+ {val: (*[]byte)(nil), output: "80"},
+ {val: (*[10]byte)(nil), output: "80"},
+ {val: (*big.Int)(nil), output: "80"},
+ {val: (*[]string)(nil), output: "C0"},
+ {val: (*[10]string)(nil), output: "C0"},
+ {val: (*[]interface{})(nil), output: "C0"},
+ {val: (*[]struct{ uint })(nil), output: "C0"},
+ {val: (*interface{})(nil), output: "C0"},
+
+ // nil struct fields
+ {
+ val: struct {
+ X *[]byte
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *[2]byte
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *uint64
+ }{},
+ output: "C180",
+ },
+ {
+ val: struct {
+ X *uint64 `rlp:"nilList"`
+ }{},
+ output: "C1C0",
+ },
+ {
+ val: struct {
+ X *[]uint64
+ }{},
+ output: "C1C0",
+ },
+ {
+ val: struct {
+ X *[]uint64 `rlp:"nilString"`
+ }{},
+ output: "C180",
+ },
+
+ // interfaces
+ {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct
+
+ // Encoder
+ {val: (*testEncoder)(nil), output: "C0"},
+ {val: &testEncoder{}, output: "00010001000100010001"},
+ {val: &testEncoder{errors.New("test error")}, error: "test error"},
+ {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"},
+ {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"},
+
+ // Verify that the Encoder interface works for unsupported types like func().
+ {val: undecodableEncoder(func() {}), output: "F5F5F5"},
+
+ // Verify that pointer method testEncoder.EncodeRLP is called for
+ // addressable non-pointer values.
+ {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"},
+ {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"},
+
+ // Verify the error for non-addressable non-pointer Encoder.
+ {val: testEncoder{}, error: "rlp: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"},
+
+ // Verify Encoder takes precedence over []byte.
+ {val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"},
+}
+
+// runEncTests runs every case in encTests through f, comparing the
+// produced bytes against the expected hex output, or the error message
+// against the expected error string. Each mismatch is reported with the
+// test index, value and type; iteration continues past failures.
+func runEncTests(t *testing.T, f func(val interface{}) ([]byte, error)) {
+ t.Helper()
+
+ for i, test := range encTests {
+ output, err := f(test.val)
+ if err != nil && test.error == "" {
+ t.Errorf("test %d: unexpected error: %v\nvalue %#v\ntype %T",
+ i, err, test.val, test.val)
+
+ continue
+ }
+
+ // fmt.Sprint(err) also handles err == nil ("<nil>" never matches).
+ if test.error != "" && fmt.Sprint(err) != test.error {
+ t.Errorf("test %d: error mismatch\ngot %v\nwant %v\nvalue %#v\ntype %T",
+ i, err, test.error, test.val, test.val)
+
+ continue
+ }
+
+ if err == nil && !bytes.Equal(output, unhex(test.output)) {
+ t.Errorf("test %d: output mismatch:\ngot %X\nwant %s\nvalue %#v\ntype %T",
+ i, output, test.output, test.val, test.val)
+ }
+ }
+}
+
+// TestEncode runs the shared table against the streaming Encode API.
+func TestEncode(t *testing.T) {
+ runEncTests(t, func(val interface{}) ([]byte, error) {
+ b := new(bytes.Buffer)
+ err := Encode(b, val)
+
+ return b.Bytes(), err
+ })
+}
+
+// TestEncodeToBytes runs the shared table against EncodeToBytes.
+func TestEncodeToBytes(t *testing.T) {
+ runEncTests(t, EncodeToBytes)
+}
+
+// TestEncodeAppendToBytes runs the shared table against
+// EncoderBuffer.AppendToBytes. The destination buffer is shared across
+// all cases (always re-sliced to length 0) to exercise the append path
+// with pre-existing capacity.
+func TestEncodeAppendToBytes(t *testing.T) {
+ buffer := make([]byte, 20)
+
+ runEncTests(t, func(val interface{}) ([]byte, error) {
+ w := NewEncoderBuffer(nil)
+ defer w.Flush()
+
+ err := Encode(w, val)
+ if err != nil {
+ return nil, err
+ }
+
+ output := w.AppendToBytes(buffer[:0])
+
+ return output, nil
+ })
+}
+
+// TestEncodeToReader runs the shared table against EncodeToReader,
+// draining the returned reader in one io.ReadAll call.
+func TestEncodeToReader(t *testing.T) {
+ runEncTests(t, func(val interface{}) ([]byte, error) {
+ _, r, err := EncodeToReader(val)
+ if err != nil {
+ return nil, err
+ }
+
+ return io.ReadAll(r)
+ })
+}
+
+// TestEncodeToReaderPiecewise runs the shared table against
+// EncodeToReader, reading the result in small (at most 3-byte) chunks
+// to exercise the reader's partial-read behavior.
+func TestEncodeToReaderPiecewise(t *testing.T) {
+ runEncTests(t, func(val interface{}) ([]byte, error) {
+ size, r, err := EncodeToReader(val)
+ if err != nil {
+ return nil, err
+ }
+
+ // read output piecewise
+ output := make([]byte, size)
+ for start, end := 0, 0; start < size; start = end {
+ if remaining := size - start; remaining < 3 {
+ end += remaining
+ } else {
+ end = start + 3
+ }
+
+ n, err := r.Read(output[start:end])
+ // Adjust end to the number of bytes actually read.
+ end = start + n
+
+ if err == io.EOF {
+ break
+ } else if err != nil {
+ return nil, err
+ }
+ }
+
+ return output, nil
+ })
+}
+
+// This is a regression test verifying that encReader
+// returns its encbuf to the pool only once.
+// Several goroutines hammer EncodeToReader concurrently; the extra Read
+// calls after io.ReadAll has drained the reader would corrupt the pool
+// (and trip the race detector) if the buffer were returned twice.
+func TestEncodeToReaderReturnToPool(t *testing.T) {
+ buf := make([]byte, 50)
+ wg := new(sync.WaitGroup)
+
+ for i := 0; i < 5; i++ {
+ wg.Add(1)
+
+ go func() {
+ for i := 0; i < 1000; i++ {
+ _, r, _ := EncodeToReader("foo")
+ io.ReadAll(r)
+ // Reads past EOF must be harmless.
+ r.Read(buf)
+ r.Read(buf)
+ r.Read(buf)
+ r.Read(buf)
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+// sink keeps benchmark results alive so the compiler cannot eliminate
+// the benchmarked call as dead code.
+var sink interface{}
+
+// BenchmarkIntsize measures the intsize helper on a 4-byte value.
+func BenchmarkIntsize(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ sink = intsize(0x12345678)
+ }
+}
+
+// BenchmarkPutint measures the putint helper on a 4-byte value,
+// writing into a reused 8-byte scratch buffer.
+func BenchmarkPutint(b *testing.B) {
+ buf := make([]byte, 8)
+ for i := 0; i < b.N; i++ {
+ putint(buf, 0x12345678)
+ sink = buf
+ }
+}
+
+// BenchmarkEncodeBigInts measures encoding a slice of 200 big.Ints
+// (powers of two, so sizes from 1 bit up to 199 bits) into a reused
+// buffer.
+func BenchmarkEncodeBigInts(b *testing.B) {
+ ints := make([]*big.Int, 200)
+ for i := range ints {
+ ints[i] = new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(i)), nil)
+ }
+
+ out := bytes.NewBuffer(make([]byte, 0, 4096))
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+
+ if err := Encode(out, ints); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// BenchmarkEncodeConcurrentInterface measures interface-typed encoding
+// under concurrency: one goroutine per CPU encodes the same mixed value
+// b.N times, stressing the encoder's internal caches/pools.
+func BenchmarkEncodeConcurrentInterface(b *testing.B) {
+ type struct1 struct {
+ A string
+ B *big.Int
+ C [20]byte
+ }
+
+ value := []interface{}{
+ uint(999),
+ &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)},
+ [10]byte{1, 2, 3, 4, 5, 6},
+ []string{"yeah", "yeah", "yeah"},
+ }
+
+ var wg sync.WaitGroup
+
+ for cpu := 0; cpu < runtime.NumCPU(); cpu++ {
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ // Each goroutine gets its own buffer; value is shared read-only.
+ var buffer bytes.Buffer
+ for i := 0; i < b.N; i++ {
+ buffer.Reset()
+
+ err := Encode(&buffer, value)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// byteArrayStruct is a benchmark fixture made of fixed-size byte
+// arrays, shaped like common address/hash-bearing structs.
+type byteArrayStruct struct {
+ A [20]byte
+ B [32]byte
+ C [32]byte
+}
+
+// BenchmarkEncodeByteArrayStruct measures encoding a zero-valued
+// byteArrayStruct into a reused buffer.
+func BenchmarkEncodeByteArrayStruct(b *testing.B) {
+ var (
+ out bytes.Buffer
+ value byteArrayStruct
+ )
+
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+
+ if err := Encode(&out, &value); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+// structSliceElem is the element type for the pointer-slice benchmark.
+type structSliceElem struct {
+ X uint64
+ Y uint64
+ Z uint64
+}
+
+// structPtrSlice is a named slice of struct pointers.
+type structPtrSlice []*structSliceElem
+
+// BenchmarkEncodeStructPtrSlice measures encoding a slice of struct
+// pointers into a reused buffer.
+func BenchmarkEncodeStructPtrSlice(b *testing.B) {
+ var (
+ out bytes.Buffer
+ value = structPtrSlice{
+ &structSliceElem{1, 1, 1},
+ &structSliceElem{2, 2, 2},
+ &structSliceElem{3, 3, 3},
+ &structSliceElem{5, 5, 5},
+ &structSliceElem{6, 6, 6},
+ &structSliceElem{7, 7, 7},
+ }
+ )
+
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ out.Reset()
+
+ if err := Encode(&out, &value); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
diff --git a/helper/rlp/encoder_example_test.go b/helper/rlp/encoder_example_test.go
new file mode 100644
index 0000000000..01cab21cd6
--- /dev/null
+++ b/helper/rlp/encoder_example_test.go
@@ -0,0 +1,48 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp_test
+
+import (
+ "fmt"
+ "io"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+)
+
+// MyCoolType demonstrates a custom Encoder implementation: the exported
+// Name field is deliberately excluded from the encoding.
+type MyCoolType struct {
+ Name string
+ a, b uint
+}
+
+// EncodeRLP writes x as RLP list [a, b] that omits the Name field.
+// (A nil *MyCoolType is handled by the rlp package before this method
+// is reached — see ExampleEncoder.)
+func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) {
+ return rlp.Encode(w, []uint{x.a, x.b})
+}
+
+// ExampleEncoder shows that a nil pointer with a pointer-method Encoder
+// encodes as an empty list (C0), while a non-nil value goes through
+// EncodeRLP.
+func ExampleEncoder() {
+ var t *MyCoolType // t is nil pointer to MyCoolType
+ bytes, _ := rlp.EncodeToBytes(t)
+ fmt.Printf("%v → %X\n", t, bytes)
+
+ t = &MyCoolType{Name: "foobar", a: 5, b: 6}
+ bytes, _ = rlp.EncodeToBytes(t)
+ fmt.Printf("%v → %X\n", t, bytes)
+
+ // Output:
+ // → C0
+ // &{foobar 5 6} → C20506
+}
diff --git a/helper/rlp/internal/rlpstruct/rlpstruct.go b/helper/rlp/internal/rlpstruct/rlpstruct.go
new file mode 100644
index 0000000000..b492d0ab85
--- /dev/null
+++ b/helper/rlp/internal/rlpstruct/rlpstruct.go
@@ -0,0 +1,238 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package rlpstruct implements struct processing for RLP encoding/decoding.
+//
+// In particular, this package handles all rules around field filtering,
+// struct tags and nil value determination.
+package rlpstruct
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// Field represents a struct field.
+type Field struct {
+ Name string
+ Index int // field index within the struct, as seen by reflect
+ Exported bool
+ Type Type
+ Tag string // raw struct tag string
+}
+
+// Type represents the attributes of a Go type. It decouples this
+// package from the rlp package's reflection internals.
+type Type struct {
+ Name string
+ Kind reflect.Kind
+ IsEncoder bool // whether type implements rlp.Encoder
+ IsDecoder bool // whether type implements rlp.Decoder
+ Elem *Type // non-nil for Kind values of Ptr, Slice, Array
+}
+
+// DefaultNilValue determines whether a nil pointer to t encodes/decodes
+// as an empty string or empty list. Scalar-like types (unsigned ints,
+// string, bool, byte slices/arrays) default to the empty string; every
+// other kind defaults to the empty list.
+func (t Type) DefaultNilValue() NilKind {
+ k := t.Kind
+ if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) {
+ return NilKindString
+ }
+
+ return NilKindList
+}
+
+// NilKind is the RLP value encoded in place of nil pointers.
+type NilKind uint8
+
+const (
+ NilKindString NilKind = 0x80 // RLP empty string
+ NilKindList NilKind = 0xC0 // RLP empty list
+)
+
+// Tags represents struct tags parsed from an rlp:"..." field tag.
+type Tags struct {
+ // rlp:"nil" controls whether empty input results in a nil pointer.
+ // nilKind is the kind of empty value allowed for the field.
+ NilKind NilKind
+ NilOK bool
+
+ // rlp:"optional" allows for a field to be missing in the input list.
+ // If this is set, all subsequent fields must also be optional.
+ Optional bool
+
+ // rlp:"tail" controls whether this field swallows additional list elements. It can
+ // only be set for the last field, which must be of slice type.
+ Tail bool
+
+ // rlp:"-" ignores fields.
+ Ignored bool
+}
+
+// TagError is raised for invalid struct tags.
+type TagError struct {
+ StructType string
+
+ // These are set by this package.
+ Field string
+ Tag string
+ Err string
+}
+
+// Error formats the tag error, prefixing the field with the struct type
+// name when the caller has filled it in.
+func (e TagError) Error() string {
+ field := "field " + e.Field
+
+ if e.StructType != "" {
+ field = e.StructType + "." + e.Field
+ }
+
+ return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err)
+}
+
+// ProcessFields filters the given struct fields, returning only fields
+// that should be considered for encoding/decoding. The returned tags
+// slice is parallel to the returned fields slice. Unexported fields and
+// fields tagged rlp:"-" are dropped; invalid tags yield a TagError.
+func ProcessFields(allFields []Field) ([]Field, []Tags, error) {
+ var (
+ // Gather all exported fields and their tags.
+ fields = make([]Field, 0, len(allFields))
+ tags = make([]Tags, 0, len(allFields))
+ lastPublic = lastPublicField(allFields)
+ )
+
+ for _, field := range allFields {
+ if !field.Exported {
+ continue
+ }
+
+ ts, err := parseTag(field, lastPublic)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if ts.Ignored {
+ continue
+ }
+
+ fields = append(fields, field)
+ tags = append(tags, ts)
+ }
+
+ // Verify optional field consistency. If any optional field exists,
+ // all fields after it must also be optional. Note: optional + tail
+ // is supported.
+ var (
+ anyOptional bool
+ firstOptionalName string
+ )
+
+ for i, ts := range tags {
+ name := fields[i].Name
+
+ if ts.Optional || ts.Tail {
+ if !anyOptional {
+ // Remember the first optional field for the error message.
+ firstOptionalName = name
+ }
+
+ anyOptional = true
+ } else {
+ if anyOptional {
+ msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName)
+
+ return nil, nil, TagError{Field: name, Err: msg}
+ }
+ }
+ }
+
+ return fields, tags, nil
+}
+
+// parseTag parses the comma-separated rlp:"..." tag of field.
+// lastPublic is the index of the last exported field in the struct,
+// used to enforce that "tail" may only appear on the last field.
+// Unknown tag words are rejected with a TagError.
+func parseTag(field Field, lastPublic int) (Tags, error) {
+ name := field.Name
+ tag := reflect.StructTag(field.Tag)
+
+ var ts Tags
+
+ for _, t := range strings.Split(tag.Get("rlp"), ",") {
+ switch t = strings.TrimSpace(t); t {
+ case "":
+ // empty tag is allowed for some reason
+ case "-":
+ ts.Ignored = true
+ case "nil", "nilString", "nilList":
+ ts.NilOK = true
+
+ if field.Type.Kind != reflect.Ptr {
+ return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"}
+ }
+
+ // "nil" derives the empty-value kind from the pointed-to type;
+ // the other two force it explicitly.
+ switch t {
+ case "nil":
+ ts.NilKind = field.Type.Elem.DefaultNilValue()
+ case "nilString":
+ ts.NilKind = NilKindString
+ case "nilList":
+ ts.NilKind = NilKindList
+ }
+ case "optional":
+ ts.Optional = true
+ if ts.Tail {
+ return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`}
+ }
+ case "tail":
+ ts.Tail = true
+
+ if field.Index != lastPublic {
+ return ts, TagError{Field: name, Tag: t, Err: "must be on last field"}
+ }
+
+ if ts.Optional {
+ return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`}
+ }
+
+ if field.Type.Kind != reflect.Slice {
+ return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"}
+ }
+ default:
+ return ts, TagError{Field: name, Tag: t, Err: "unknown tag"}
+ }
+ }
+
+ return ts, nil
+}
+
+// lastPublicField returns the Index of the last exported field in
+// fields, or 0 when there is no exported field.
+func lastPublicField(fields []Field) int {
+ last := 0
+
+ for _, f := range fields {
+ if f.Exported {
+ last = f.Index
+ }
+ }
+
+ return last
+}
+
+// isUint reports whether k is an unsigned integer kind
+// (uint, uint8..uint64, uintptr).
+func isUint(k reflect.Kind) bool {
+ return k >= reflect.Uint && k <= reflect.Uintptr
+}
+
+// isByte reports whether typ behaves as a raw byte: kind uint8 and not
+// carrying its own Encoder implementation (Encoder takes precedence).
+func isByte(typ Type) bool {
+ return typ.Kind == reflect.Uint8 && !typ.IsEncoder
+}
+
+// isByteArray reports whether typ is a slice or array of raw bytes.
+func isByteArray(typ Type) bool {
+ return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem)
+}
diff --git a/helper/rlp/iterator.go b/helper/rlp/iterator.go
new file mode 100644
index 0000000000..0d144f0d4d
--- /dev/null
+++ b/helper/rlp/iterator.go
@@ -0,0 +1,65 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+// listIterator iterates over the top-level elements of an RLP list.
+type listIterator struct {
+ data []byte // remaining, not-yet-visited list payload
+ next []byte // current element (header + payload), set by Next
+ err error // error from the most recent Next, if any
+}
+
+// NewListIterator creates an iterator for the (list) represented by data
+// TODO: Consider removing this implementation, as it is no longer used.
+// It validates only the outer header: data must decode to kind List,
+// otherwise ErrExpectedList is returned. The iterator is positioned
+// over the list payload (header stripped).
+func NewListIterator(data RawValue) (*listIterator, error) {
+ k, t, c, err := readKind(data)
+ if err != nil {
+ return nil, err
+ }
+
+ if k != List {
+ return nil, ErrExpectedList
+ }
+
+ it := &listIterator{
+ data: data[t : t+c],
+ }
+
+ return it, nil
+}
+
+// Next forwards the iterator one step, returns true if it was not at end yet
+// Note that a malformed element still returns true with it.err set;
+// callers should check Err() after the loop.
+func (it *listIterator) Next() bool {
+ if len(it.data) == 0 {
+ return false
+ }
+
+ _, t, c, err := readKind(it.data)
+ it.next = it.data[:t+c]
+ it.data = it.data[t+c:]
+ it.err = err
+
+ return true
+}
+
+// Value returns the current value
+// (the full RLP encoding of the element selected by the last Next call).
+func (it *listIterator) Value() []byte {
+ return it.next
+}
+
+// Err returns the error recorded by the most recent Next call, if any.
+func (it *listIterator) Err() error {
+ return it.err
+}
diff --git a/helper/rlp/iterator_test.go b/helper/rlp/iterator_test.go
new file mode 100644
index 0000000000..551e8b6d64
--- /dev/null
+++ b/helper/rlp/iterator_test.go
@@ -0,0 +1,70 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "encoding/hex"
+ "testing"
+)
+
+// TestIterator tests some basic things about the ListIterator. A more
+// comprehensive test can be found in core/rlp_test.go, where we can
+// use both types and rlp without dependency cycles
+func TestIterator(t *testing.T) {
+ // use hex string without prefix "0x"
+ bodyRlpHex := "f902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000"
+
+ bodyRlp, err := hex.DecodeString(bodyRlpHex)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ it, err := NewListIterator(bodyRlp)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Check that txs exist
+ if !it.Next() {
+ t.Fatal("expected two elems, got zero")
+ }
+
+ txs := it.Value()
+
+ // Check that uncles exist
+ if !it.Next() {
+ t.Fatal("expected two elems, got one")
+ }
+
+ txit, err := NewListIterator(txs)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var i = 0
+
+ for txit.Next() {
+ if txit.err != nil {
+ t.Fatal(txit.err)
+ }
+
+ i++
+ }
+
+ if exp := 2; i != exp {
+ t.Errorf("count wrong, expected %d got %d", i, exp)
+ }
+}
diff --git a/helper/rlp/raw.go b/helper/rlp/raw.go
new file mode 100644
index 0000000000..930c31066d
--- /dev/null
+++ b/helper/rlp/raw.go
@@ -0,0 +1,317 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "io"
+ "reflect"
+)
+
// RawValue represents an encoded RLP value and can be used to delay
// RLP decoding or to precompute an encoding. Note that the decoder does
// not verify whether the content of RawValues is valid RLP.
type RawValue []byte

// rawValueType is the reflect.Type of RawValue, used to special-case
// raw values during reflection-based encoding/decoding.
var rawValueType = reflect.TypeOf(RawValue{})
+
+// StringSize returns the encoded size of a string.
+func StringSize(s string) uint64 {
+ switch {
+ case len(s) == 0:
+ return 1
+ case len(s) == 1:
+ if s[0] <= 0x7f {
+ return 1
+ } else {
+ return 2
+ }
+ default:
+ return uint64(headsize(uint64(len(s))) + len(s))
+ }
+}
+
+// BytesSize returns the encoded size of a byte slice.
+func BytesSize(b []byte) uint64 {
+ switch {
+ case len(b) == 0:
+ return 1
+ case len(b) == 1:
+ if b[0] <= 0x7f {
+ return 1
+ } else {
+ return 2
+ }
+ default:
+ return uint64(headsize(uint64(len(b))) + len(b))
+ }
+}
+
// ListSize returns the encoded size of an RLP list with the given
// content size.
func ListSize(contentSize uint64) uint64 {
	// Header size depends on whether contentSize fits the short-list form.
	return uint64(headsize(contentSize)) + contentSize
}
+
// IntSize returns the encoded size of the integer x. Note: The return type of this
// function is 'int' for backwards-compatibility reasons. The result is always positive.
func IntSize(x uint64) int {
	// Values below 0x80 encode as a single byte with no header.
	if x < 0x80 {
		return 1
	}

	// One header byte plus the big-endian bytes of x.
	return 1 + intsize(x)
}
+
+// Split returns the content of first RLP value and any
+// bytes after the value as subslices of b.
+func Split(b []byte) (k Kind, content, rest []byte, err error) {
+ k, ts, cs, err := readKind(b)
+ if err != nil {
+ return 0, nil, b, err
+ }
+
+ return k, b[ts : ts+cs], b[ts+cs:], nil
+}
+
+// SplitString splits b into the content of an RLP string
+// and any remaining bytes after the string.
+func SplitString(b []byte) (content, rest []byte, err error) {
+ k, content, rest, err := Split(b)
+ if err != nil {
+ return nil, b, err
+ }
+
+ if k == List {
+ return nil, b, ErrExpectedString
+ }
+
+ return content, rest, nil
+}
+
+// SplitUint64 decodes an integer at the beginning of b.
+// It also returns the remaining data after the integer in 'rest'.
+func SplitUint64(b []byte) (x uint64, rest []byte, err error) {
+ content, rest, err := SplitString(b)
+ if err != nil {
+ return 0, b, err
+ }
+
+ switch {
+ case len(content) == 0:
+ return 0, rest, nil
+ case len(content) == 1:
+ if content[0] == 0 {
+ return 0, b, ErrCanonInt
+ }
+
+ return uint64(content[0]), rest, nil
+ case len(content) > 8:
+ return 0, b, errUintOverflow
+ default:
+ x, err = readSize(content, byte(len(content)))
+ if err != nil {
+ return 0, b, ErrCanonInt
+ }
+
+ return x, rest, nil
+ }
+}
+
+// SplitList splits b into the content of a list and any remaining
+// bytes after the list.
+func SplitList(b []byte) (content, rest []byte, err error) {
+ k, content, rest, err := Split(b)
+ if err != nil {
+ return nil, b, err
+ }
+
+ if k != List {
+ return nil, b, ErrExpectedList
+ }
+
+ return content, rest, nil
+}
+
+// CountValues counts the number of encoded values in b.
+func CountValues(b []byte) (int, error) {
+ i := 0
+
+ for ; len(b) > 0; i++ {
+ _, tagsize, size, err := readKind(b)
+ if err != nil {
+ return 0, err
+ }
+
+ b = b[tagsize+size:]
+ }
+
+ return i, nil
+}
+
// readKind parses the type tag at the start of buf and returns the value's
// kind, the size of its header (tagsize) and the size of its content.
// It enforces canonical encoding: single bytes must not carry a string
// header, and sizes must use the shortest possible form.
func readKind(buf []byte) (k Kind, tagsize, contentsize uint64, err error) {
	if len(buf) == 0 {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}

	b := buf[0]

	switch {
	case b < 0x80:
		// Single byte below 0x80 encodes as itself, no header.
		k = Byte
		tagsize = 0
		contentsize = 1
	case b < 0xB8:
		// Short string: length 0..55 is embedded in the tag byte.
		k = String
		tagsize = 1
		contentsize = uint64(b - 0x80)
		// Reject strings that should've been single bytes.
		if contentsize == 1 && len(buf) > 1 && buf[1] < 128 {
			return 0, 0, 0, ErrCanonSize
		}
	case b < 0xC0:
		// Long string: tag byte gives the byte length of the size field.
		k = String
		tagsize = uint64(b-0xB7) + 1
		contentsize, err = readSize(buf[1:], b-0xB7)
	case b < 0xF8:
		// Short list: payload length 0..55 is embedded in the tag byte.
		k = List
		tagsize = 1
		contentsize = uint64(b - 0xC0)
	default:
		// Long list: tag byte gives the byte length of the size field.
		k = List
		tagsize = uint64(b-0xF7) + 1
		contentsize, err = readSize(buf[1:], b-0xF7)
	}

	if err != nil {
		return 0, 0, 0, err
	}
	// Reject values larger than the input slice.
	if contentsize > uint64(len(buf))-tagsize {
		return 0, 0, 0, ErrValueTooLarge
	}

	return k, tagsize, contentsize, err
}
+
// readSize decodes the big-endian size field of a long string or list.
// slen is the byte length of the size field as derived from the tag byte;
// callers pass values in 1..8 (any other slen yields s == 0 and ErrCanonSize).
func readSize(b []byte, slen byte) (uint64, error) {
	if int(slen) > len(b) {
		return 0, io.ErrUnexpectedEOF
	}

	var s uint64

	switch slen {
	case 1:
		s = uint64(b[0])
	case 2:
		s = uint64(b[0])<<8 | uint64(b[1])
	case 3:
		s = uint64(b[0])<<16 | uint64(b[1])<<8 | uint64(b[2])
	case 4:
		s = uint64(b[0])<<24 | uint64(b[1])<<16 | uint64(b[2])<<8 | uint64(b[3])
	case 5:
		s = uint64(b[0])<<32 | uint64(b[1])<<24 | uint64(b[2])<<16 | uint64(b[3])<<8 | uint64(b[4])
	case 6:
		s = uint64(b[0])<<40 | uint64(b[1])<<32 | uint64(b[2])<<24 | uint64(b[3])<<16 | uint64(b[4])<<8 |
			uint64(b[5])
	case 7:
		s = uint64(b[0])<<48 | uint64(b[1])<<40 | uint64(b[2])<<32 | uint64(b[3])<<24 | uint64(b[4])<<16 |
			uint64(b[5])<<8 | uint64(b[6])
	case 8:
		s = uint64(b[0])<<56 | uint64(b[1])<<48 | uint64(b[2])<<40 | uint64(b[3])<<32 | uint64(b[4])<<24 |
			uint64(b[5])<<16 | uint64(b[6])<<8 | uint64(b[7])
	}
	// Reject sizes < 56 (shouldn't have separate size) and sizes with
	// leading zero bytes.
	if s < 56 || b[0] == 0 {
		return 0, ErrCanonSize
	}

	return s, nil
}
+
+// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice.
+func AppendUint64(b []byte, i uint64) []byte {
+ if i == 0 {
+ return append(b, 0x80)
+ } else if i < 128 {
+ return append(b, byte(i))
+ }
+
+ switch {
+ case i < (1 << 8):
+ return append(b, 0x81, byte(i))
+ case i < (1 << 16):
+ return append(b, 0x82,
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 24):
+ return append(b, 0x83,
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 32):
+ return append(b, 0x84,
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 40):
+ return append(b, 0x85,
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+
+ case i < (1 << 48):
+ return append(b, 0x86,
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ case i < (1 << 56):
+ return append(b, 0x87,
+ byte(i>>48),
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+
+ default:
+ return append(b, 0x88,
+ byte(i>>56),
+ byte(i>>48),
+ byte(i>>40),
+ byte(i>>32),
+ byte(i>>24),
+ byte(i>>16),
+ byte(i>>8),
+ byte(i),
+ )
+ }
+}
diff --git a/helper/rlp/raw_test.go b/helper/rlp/raw_test.go
new file mode 100644
index 0000000000..17310310d5
--- /dev/null
+++ b/helper/rlp/raw_test.go
@@ -0,0 +1,349 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package rlp
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+ "testing/quick"
+)
+
// TestCountValues checks CountValues on well-formed inputs and verifies
// that malformed or non-canonical items propagate the decoding error.
func TestCountValues(t *testing.T) {
	tests := []struct {
		input string // note: spaces in input are stripped by unhex
		count int
		err error
	}{
		// simple cases
		{"", 0, nil},
		{"00", 1, nil},
		{"80", 1, nil},
		{"C0", 1, nil},
		{"01 02 03", 3, nil},
		{"01 C406070809 02", 3, nil},
		{"820101 820202 8403030303 04", 4, nil},

		// size errors
		{"8142", 0, ErrCanonSize},
		{"01 01 8142", 0, ErrCanonSize},
		{"02 84020202", 0, ErrValueTooLarge},

		{
			input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
			count: 2,
		},
	}
	for i, test := range tests {
		count, err := CountValues(unhex(test.input))
		if count != test.count {
			t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input)
		}

		if !errors.Is(err, test.err) {
			t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input)
		}
	}
}
+
+func TestSplitString(t *testing.T) {
+ for i, test := range []string{
+ "C0",
+ "C100",
+ "C3010203",
+ "C88363617483646F67",
+ "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974",
+ } {
+ if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) {
+ t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString)
+ }
+ }
+}
+
+func TestSplitList(t *testing.T) {
+ for i, test := range []string{
+ "80",
+ "00",
+ "01",
+ "8180",
+ "81FF",
+ "820400",
+ "83636174",
+ "83646F67",
+ "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974",
+ } {
+ if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) {
+ t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList)
+ }
+ }
+}
+
// TestSplitUint64 covers all encoded widths (0 through 8 payload bytes)
// plus the canonical-form and overflow error paths.
func TestSplitUint64(t *testing.T) {
	tests := []struct {
		input string
		val uint64
		rest string
		err error
	}{
		{"01", 1, "", nil},
		{"7FFF", 0x7F, "FF", nil},
		{"80FF", 0, "FF", nil},
		{"81FAFF", 0xFA, "FF", nil},
		{"82FAFAFF", 0xFAFA, "FF", nil},
		{"83FAFAFAFF", 0xFAFAFA, "FF", nil},
		{"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil},
		{"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil},
		{"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil},
		{"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil},
		{"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil},

		// errors
		{"", 0, "", io.ErrUnexpectedEOF},
		{"00", 0, "00", ErrCanonInt},
		{"81", 0, "81", ErrValueTooLarge},
		{"8100", 0, "8100", ErrCanonSize},
		{"8200FF", 0, "8200FF", ErrCanonInt},
		{"8103FF", 0, "8103FF", ErrCanonSize},
		{"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow},
	}

	for i, test := range tests {
		val, rest, err := SplitUint64(unhex(test.input))
		if val != test.val {
			t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input)
		}

		if !bytes.Equal(rest, unhex(test.rest)) {
			t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input)
		}

		if !errors.Is(err, test.err) {
			t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
		}
	}
}
+
// TestSplit checks Split across all three kinds (Byte, String, List),
// the canonical-size and too-large error paths, size-field overflow,
// and a few realistic large payloads (trie nodes / block fragments).
func TestSplit(t *testing.T) {
	tests := []struct {
		input string
		kind Kind
		val, rest string
		err error
	}{
		{input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"},
		{input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"},
		{input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"},
		{input: "80FFFF", kind: String, val: "", rest: "FFFF"},
		{input: "C3010203", kind: List, val: "010203"},

		// errors
		{input: "", err: io.ErrUnexpectedEOF},

		{input: "8141", err: ErrCanonSize, rest: "8141"},
		{input: "B800", err: ErrCanonSize, rest: "B800"},
		{input: "B802FFFF", err: ErrCanonSize, rest: "B802FFFF"},
		{input: "B90000", err: ErrCanonSize, rest: "B90000"},
		{input: "B90055", err: ErrCanonSize, rest: "B90055"},
		{input: "BA0002FFFF", err: ErrCanonSize, rest: "BA0002FFFF"},
		{input: "F800", err: ErrCanonSize, rest: "F800"},
		{input: "F90000", err: ErrCanonSize, rest: "F90000"},
		{input: "F90055", err: ErrCanonSize, rest: "F90055"},
		{input: "FA0002FFFF", err: ErrCanonSize, rest: "FA0002FFFF"},

		{input: "81", err: ErrValueTooLarge, rest: "81"},
		{input: "8501010101", err: ErrValueTooLarge, rest: "8501010101"},
		{input: "C60607080902", err: ErrValueTooLarge, rest: "C60607080902"},

		// size check overflow
		{input: "BFFFFFFFFFFFFFFFFF", err: ErrValueTooLarge, rest: "BFFFFFFFFFFFFFFFFF"},
		{input: "FFFFFFFFFFFFFFFFFF", err: ErrValueTooLarge, rest: "FFFFFFFFFFFFFFFFFF"},

		{
			input: "B838FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
			err: ErrValueTooLarge,
			rest: "B838FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
		},
		{
			input: "F838FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
			err: ErrValueTooLarge,
			rest: "F838FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
		},

		// a few bigger values, just for kicks
		{
			input: "F839FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
			kind: List,
			val: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
			rest: "",
		},
		{
			input: "F90211A060EF29F20CC1007AE6E9530AEE16F4B31F8F1769A2D1264EC995C6D1241868D6A07C62AB8AC9838F5F5877B20BB37B387BC2106E97A3D52172CBEDB5EE17C36008A00EAB6B7324AADC0F6047C6AFC8229F09F7CF451B51D67C8DFB08D49BA8C3C626A04453343B2F3A6E42FCF87948F88AF7C8FC16D0C2735CBA7F026836239AB2C15FA024635C7291C882CE4C0763760C1A362DFC3FFCD802A55722236DE058D74202ACA0A220C808DE10F55E40AB25255201CFF009EA181D3906638E944EE2BF34049984A08D325AB26796F1CCB470F69C0F842501DC35D368A0C2575B2D243CFD1E8AB0FDA0B5298FF60DA5069463D610513C9F04F24051348391A143AFFAB7197DFACDEA72A02D2A7058A4463F8FB69378369E11EF33AE3252E2DB86CB545B36D3C26DDECE5AA0888F97BCA8E0BD83DC5B3B91CFF5FAF2F66F9501010682D67EF4A3B4E66115FBA0E8175A60C93BE9ED02921958F0EA55DA0FB5E4802AF5846147BAD92BC2D8AF26A08B3376FF433F3A4250FA64B7F804004CAC5807877D91C4427BD1CD05CF912ED8A09B32EF0F03BD13C37FF950C0CCCEFCCDD6669F2E7F2AA5CB859928E84E29763EA09BBA5E46610C8C8B1F8E921E5691BF8C7E40D75825D5EA3217AA9C3A8A355F39A0EEB95BC78251CCCEC54A97F19755C4A59A293544EEE6119AFA50531211E53C4FA00B6E86FE150BF4A9E0FEEE9C90F5465E617A861BB5E357F942881EE762212E2580",
			kind: List,
			val: "A060EF29F20CC1007AE6E9530AEE16F4B31F8F1769A2D1264EC995C6D1241868D6A07C62AB8AC9838F5F5877B20BB37B387BC2106E97A3D52172CBEDB5EE17C36008A00EAB6B7324AADC0F6047C6AFC8229F09F7CF451B51D67C8DFB08D49BA8C3C626A04453343B2F3A6E42FCF87948F88AF7C8FC16D0C2735CBA7F026836239AB2C15FA024635C7291C882CE4C0763760C1A362DFC3FFCD802A55722236DE058D74202ACA0A220C808DE10F55E40AB25255201CFF009EA181D3906638E944EE2BF34049984A08D325AB26796F1CCB470F69C0F842501DC35D368A0C2575B2D243CFD1E8AB0FDA0B5298FF60DA5069463D610513C9F04F24051348391A143AFFAB7197DFACDEA72A02D2A7058A4463F8FB69378369E11EF33AE3252E2DB86CB545B36D3C26DDECE5AA0888F97BCA8E0BD83DC5B3B91CFF5FAF2F66F9501010682D67EF4A3B4E66115FBA0E8175A60C93BE9ED02921958F0EA55DA0FB5E4802AF5846147BAD92BC2D8AF26A08B3376FF433F3A4250FA64B7F804004CAC5807877D91C4427BD1CD05CF912ED8A09B32EF0F03BD13C37FF950C0CCCEFCCDD6669F2E7F2AA5CB859928E84E29763EA09BBA5E46610C8C8B1F8E921E5691BF8C7E40D75825D5EA3217AA9C3A8A355F39A0EEB95BC78251CCCEC54A97F19755C4A59A293544EEE6119AFA50531211E53C4FA00B6E86FE150BF4A9E0FEEE9C90F5465E617A861BB5E357F942881EE762212E2580",
			rest: "",
		},
		{
			input: "F877A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
			kind: List,
			val: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
			rest: "",
		},
	}

	for i, test := range tests {
		kind, val, rest, err := Split(unhex(test.input))
		if kind != test.kind {
			t.Errorf("test %d: kind mismatch: got %v, want %v", i, kind, test.kind)
		}

		if !bytes.Equal(val, unhex(test.val)) {
			t.Errorf("test %d: val mismatch: got %x, want %s", i, val, test.val)
		}

		if !bytes.Equal(rest, unhex(test.rest)) {
			t.Errorf("test %d: rest mismatch: got %x, want %s", i, rest, test.rest)
		}

		if !errors.Is(err, test.err) {
			t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
		}
	}
}
+
// TestReadSize exercises readSize over every valid size-field width (1..8)
// plus the short-input and non-canonical (size < 56 or leading zero) errors.
func TestReadSize(t *testing.T) {
	tests := []struct {
		input string
		slen byte
		size uint64
		err error
	}{
		{input: "", slen: 1, err: io.ErrUnexpectedEOF},
		{input: "FF", slen: 2, err: io.ErrUnexpectedEOF},
		{input: "00", slen: 1, err: ErrCanonSize},
		{input: "36", slen: 1, err: ErrCanonSize},
		{input: "37", slen: 1, err: ErrCanonSize},
		{input: "38", slen: 1, size: 0x38}, // 0x38 == 56, the smallest legal long size
		{input: "FF", slen: 1, size: 0xFF},
		{input: "FFFF", slen: 2, size: 0xFFFF},
		{input: "FFFFFF", slen: 3, size: 0xFFFFFF},
		{input: "FFFFFFFF", slen: 4, size: 0xFFFFFFFF},
		{input: "FFFFFFFFFF", slen: 5, size: 0xFFFFFFFFFF},
		{input: "FFFFFFFFFFFF", slen: 6, size: 0xFFFFFFFFFFFF},
		{input: "FFFFFFFFFFFFFF", slen: 7, size: 0xFFFFFFFFFFFFFF},
		{input: "FFFFFFFFFFFFFFFF", slen: 8, size: 0xFFFFFFFFFFFFFFFF},
		{input: "0102", slen: 2, size: 0x0102},
		{input: "010203", slen: 3, size: 0x010203},
		{input: "01020304", slen: 4, size: 0x01020304},
		{input: "0102030405", slen: 5, size: 0x0102030405},
		{input: "010203040506", slen: 6, size: 0x010203040506},
		{input: "01020304050607", slen: 7, size: 0x01020304050607},
		{input: "0102030405060708", slen: 8, size: 0x0102030405060708},
	}

	for _, test := range tests {
		size, err := readSize(unhex(test.input), test.slen)
		if !errors.Is(err, test.err) {
			t.Errorf("readSize(%s, %d): error mismatch: got %q, want %q", test.input, test.slen, err, test.err)

			continue
		}

		if size != test.size {
			t.Errorf("readSize(%s, %d): size mismatch: got %#x, want %#x", test.input, test.slen, size, test.size)
		}
	}
}
+
// TestAppendUint64 checks the encoded form for representative values and
// verifies that IntSize agrees with the number of bytes actually appended.
func TestAppendUint64(t *testing.T) {
	tests := []struct {
		input uint64
		slice []byte
		output string
	}{
		{0, nil, "80"},
		{1, nil, "01"},
		{2, nil, "02"},
		{127, nil, "7F"},
		{128, nil, "8180"},
		{129, nil, "8181"},
		{0xFFFFFF, nil, "83FFFFFF"},
		{127, []byte{1, 2, 3}, "0102037F"}, // existing slice contents are preserved
		{0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"},
	}

	for _, test := range tests {
		x := AppendUint64(test.slice, test.input)
		if !bytes.Equal(x, unhex(test.output)) {
			t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output)
		}

		// Check that IntSize returns the appended size.
		length := len(x) - len(test.slice)
		if s := IntSize(test.input); s != length {
			t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length)
		}
	}
}
+
+func TestAppendUint64Random(t *testing.T) {
+ fn := func(i uint64) bool {
+ enc, _ := EncodeToBytes(i)
+ encAppend := AppendUint64(nil, i)
+
+ return bytes.Equal(enc, encAppend)
+ }
+ config := quick.Config{MaxCountScale: 50}
+
+ if err := quick.Check(fn, &config); err != nil {
+ t.Fatal(err)
+ }
+}
+
// TestBytesSize checks BytesSize and StringSize against each other and
// against the length of the actual encoding produced by EncodeToBytes.
func TestBytesSize(t *testing.T) {
	tests := []struct {
		v []byte
		size uint64
	}{
		{v: []byte{}, size: 1},
		{v: []byte{0x1}, size: 1},
		{v: []byte{0x7E}, size: 1},
		{v: []byte{0x7F}, size: 1},
		{v: []byte{0x80}, size: 2}, // first value needing a header byte
		{v: []byte{0xFF}, size: 2},
		{v: []byte{0xFF, 0xF0}, size: 3},
		{v: make([]byte, 55), size: 56}, // largest short string
		{v: make([]byte, 56), size: 58}, // smallest long string: 2-byte header
	}

	for _, test := range tests {
		s := BytesSize(test.v)
		if s != test.size {
			t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size)
		}

		s = StringSize(string(test.v))
		if s != test.size {
			t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size)
		}
		// Sanity check:
		enc, _ := EncodeToBytes(test.v)
		if uint64(len(enc)) != test.size {
			t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size)
		}
	}
}
diff --git a/helper/rlp/rlpgen/gen.go b/helper/rlp/rlpgen/gen.go
new file mode 100644
index 0000000000..5b3290f46a
--- /dev/null
+++ b/helper/rlp/rlpgen/gen.go
@@ -0,0 +1,845 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/format"
+ "go/types"
+ "sort"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp/internal/rlpstruct"
+)
+
// buildContext keeps the data needed for make*Op.
type buildContext struct {
	topType *types.Named // the type we're creating methods for

	encoderIface *types.Interface // the rlp.Encoder interface
	decoderIface *types.Interface // the rlp.Decoder interface
	rawValueType *types.Named // the rlp.RawValue named type

	// typeToStructCache memoizes typeToStructType results; it also breaks
	// cycles when converting recursive types.
	typeToStructCache map[types.Type]*rlpstruct.Type
}
+
// newBuildContext looks up the Encoder/Decoder interfaces and the RawValue
// type in the compiled rlp package and captures them for later type checks.
func newBuildContext(packageRLP *types.Package) *buildContext {
	enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying()
	dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying()
	rawv := packageRLP.Scope().Lookup("RawValue").Type()

	// The assertions hold as long as the rlp package declares Encoder and
	// Decoder as interfaces and RawValue as a named type.
	//nolint:forcetypeassert
	return &buildContext{
		typeToStructCache: make(map[types.Type]*rlpstruct.Type),
		encoderIface: enc.(*types.Interface),
		decoderIface: dec.(*types.Interface),
		rawValueType: rawv.(*types.Named),
	}
}
+
// isEncoder reports whether typ implements the rlp.Encoder interface.
func (bctx *buildContext) isEncoder(typ types.Type) bool {
	return types.Implements(typ, bctx.encoderIface)
}
+
// isDecoder reports whether typ implements the rlp.Decoder interface.
func (bctx *buildContext) isDecoder(typ types.Type) bool {
	return types.Implements(typ, bctx.decoderIface)
}
+
+// typeToStructType converts typ to rlpstruct.Type.
+func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
+ if prev := bctx.typeToStructCache[typ]; prev != nil {
+ return prev // short-circuit for recursive types.
+ }
+
+ // Resolve named types to their underlying type, but keep the name.
+ name := types.TypeString(typ, nil)
+
+ for {
+ utype := typ.Underlying()
+ if utype == typ {
+ break
+ }
+
+ typ = utype
+ }
+
+ // Create the type and store it in cache.
+ t := &rlpstruct.Type{
+ Name: name,
+ Kind: typeReflectKind(typ),
+ IsEncoder: bctx.isEncoder(typ),
+ IsDecoder: bctx.isDecoder(typ),
+ }
+ bctx.typeToStructCache[typ] = t
+
+ // Assign element type.
+ switch typ.(type) {
+ case *types.Array, *types.Slice, *types.Pointer:
+ //nolint:forcetypeassert
+ etype := typ.(interface{ Elem() types.Type }).Elem()
+ t.Elem = bctx.typeToStructType(etype)
+ }
+
+ return t
+}
+
// genContext is passed to the gen* methods of op when generating
// the output code. It tracks packages to be imported by the output
// file and assigns unique names of temporary variables.
type genContext struct {
	inPackage *types.Package // the package the generated file belongs to
	imports map[string]struct{} // set of import paths required by generated code
	tempCounter int // counter behind the _tmpN variable names
}
+
// newGenContext creates a generation context for emitting code into inPackage.
func newGenContext(inPackage *types.Package) *genContext {
	return &genContext{
		inPackage: inPackage,
		imports: make(map[string]struct{}),
	}
}
+
// temp returns a fresh temporary variable name of the form _tmpN.
func (ctx *genContext) temp() string {
	v := fmt.Sprintf("_tmp%d", ctx.tempCounter)
	ctx.tempCounter++

	return v
}
+
// resetTemp resets the temporary-variable counter to zero, so subsequent
// temp calls start again at _tmp0.
func (ctx *genContext) resetTemp() {
	ctx.tempCounter = 0
}
+
// addImport records path as an import required by the generated file.
func (ctx *genContext) addImport(path string) {
	if path == ctx.inPackage.Path() {
		return // avoid importing the package that we're generating in.
	}
	// TODO: renaming?
	ctx.imports[path] = struct{}{}
}
+
// importsList returns all packages that need to be imported.
// The list is sorted so the generated output is deterministic.
func (ctx *genContext) importsList() []string {
	imp := make([]string, 0, len(ctx.imports))
	for k := range ctx.imports {
		imp = append(imp, k)
	}

	sort.Strings(imp)

	return imp
}
+
// qualify is the types.Qualifier used for printing types.
// As a side effect it records pkg as a required import of the output file.
func (ctx *genContext) qualify(pkg *types.Package) string {
	if pkg.Path() == ctx.inPackage.Path() {
		return "" // same package: no qualifier needed
	}

	ctx.addImport(pkg.Path())
	// TODO: renaming?
	return pkg.Name()
}
+
// op generates the encoder and decoder code fragments for one Go type.
type op interface {
	// genWrite creates the encoder. The generated code should write v,
	// which is any Go expression, to the rlp.EncoderBuffer 'w'.
	genWrite(ctx *genContext, v string) string

	// genDecode creates the decoder. The generated code should read
	// a value from the rlp.Stream 'dec' and store it to dst.
	genDecode(ctx *genContext) (string, string)
}
+
// basicOp handles basic types bool, uint*, string.
type basicOp struct {
	typ types.Type
	writeMethod string // EncoderBuffer method called to write the value
	writeArgType types.Type // parameter type of writeMethod
	decMethod string // rlp.Stream method called to decode the value
	decResultType types.Type // return type of decMethod
	decUseBitSize bool // if true, result bit size is appended to decMethod
}
+
// makeBasicOp returns the op for a basic type (bool, uint8..uint64, string).
// Platform-dependent integer kinds (int, uint, uintptr) are rejected.
func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) {
	op := basicOp{typ: typ}
	kind := typ.Kind()

	switch {
	case kind == types.Bool:
		op.writeMethod = "WriteBool"
		op.writeArgType = types.Typ[types.Bool]
		op.decMethod = "Bool"
		op.decResultType = types.Typ[types.Bool]
	case kind >= types.Uint8 && kind <= types.Uint64:
		op.writeMethod = "WriteUint64"
		op.writeArgType = types.Typ[types.Uint64]
		op.decMethod = "Uint"
		op.decResultType = typ
		// Decode via UintN, where N is the type's bit size (see genDecode).
		op.decUseBitSize = true
	case kind == types.String:
		op.writeMethod = "WriteString"
		op.writeArgType = types.Typ[types.String]
		op.decMethod = "String"
		op.decResultType = types.Typ[types.String]
	default:
		return nil, fmt.Errorf("unhandled basic type: %v", typ)
	}

	return op, nil
}
+
// makeByteSliceOp returns the op for []byte, encoded/decoded as an RLP
// byte string. Panics if called with a non-byte slice (programmer error).
func (*buildContext) makeByteSliceOp(typ *types.Slice) op {
	if !isByte(typ.Elem()) {
		panic("non-byte slice type in makeByteSliceOp")
	}

	bslice := types.NewSlice(types.Typ[types.Uint8])

	return basicOp{
		typ: typ,
		writeMethod: "WriteBytes",
		writeArgType: bslice,
		decMethod: "Bytes",
		decResultType: bslice,
	}
}
+
// makeRawValueOp returns the op for rlp.RawValue, which is written and
// read verbatim (Write/Raw) without further encoding.
func (bctx *buildContext) makeRawValueOp() op {
	bslice := types.NewSlice(types.Typ[types.Uint8])

	return basicOp{
		typ: bctx.rawValueType,
		writeMethod: "Write",
		writeArgType: bslice,
		decMethod: "Raw",
		decResultType: bslice,
	}
}
+
// writeNeedsConversion reports whether the value must be converted to the
// writer method's parameter type before being passed in.
func (op basicOp) writeNeedsConversion() bool {
	return !types.AssignableTo(op.typ, op.writeArgType)
}
+
// decodeNeedsConversion reports whether the decoder method's result must be
// converted to the target type before assignment.
func (op basicOp) decodeNeedsConversion() bool {
	return !types.AssignableTo(op.decResultType, op.typ)
}
+
// genWrite emits a call to the writer method, wrapping v in a type
// conversion when its type is not assignable to the method's parameter.
func (op basicOp) genWrite(ctx *genContext, v string) string {
	if op.writeNeedsConversion() {
		v = fmt.Sprintf("%s(%s)", op.writeArgType, v)
	}

	return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v)
}
+
// genDecode emits a call to the stream's decoder method plus an error
// check, returning the name of the variable holding the decoded value and
// the generated statements. A conversion is added when the method's result
// type differs from the target type.
func (op basicOp) genDecode(ctx *genContext) (string, string) {
	var (
		resultV = ctx.temp()
		result = resultV
		method = op.decMethod
	)

	if op.decUseBitSize {
		// Note: For now, this only works for platform-independent integer
		// sizes. makeBasicOp forbids the platform-dependent types.
		var sizes types.StdSizes
		method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8)
	}

	// Call the decoder method.
	var b bytes.Buffer

	fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method)
	fmt.Fprintf(&b, "if err != nil { return err }\n")

	if op.decodeNeedsConversion() {
		conv := ctx.temp()
		fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV)
		result = conv
	}

	return result, b.String()
}
+
// byteArrayOp handles [...]byte.
type byteArrayOp struct {
	typ types.Type // the underlying array type
	name types.Type // name != typ for named byte array types (e.g. common.Address)
}
+
+func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp {
+ nt := types.Type(name)
+ if name == nil {
+ nt = typ
+ }
+
+ return byteArrayOp{typ, nt}
+}
+
+func (op byteArrayOp) genWrite(ctx *genContext, v string) string {
+ return fmt.Sprintf("w.WriteBytes(%s[:])\n", v)
+}
+
+func (op byteArrayOp) genDecode(ctx *genContext) (string, string) {
+ var (
+ resultV = ctx.temp()
+ b bytes.Buffer
+ )
+
+ fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify))
+ fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV)
+
+ return resultV, b.String()
+}
+
+// bigIntNoPtrOp handles non-pointer big.Int.
+// This exists because big.Int has its own decoder operation on rlp.Stream,
+// but the decode method returns *big.Int, so it needs to be dereferenced.
+type bigIntOp struct {
+ pointer bool
+}
+
+func (op bigIntOp) genWrite(ctx *genContext, v string) string {
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v)
+ fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n")
+ fmt.Fprintf(&b, "}\n")
+
+ dst := v
+ if !op.pointer {
+ dst = "&" + v
+ }
+
+ fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst)
+
+ // Wrap with nil check.
+ if op.pointer {
+ code := b.String()
+ b.Reset()
+ fmt.Fprintf(&b, "if %s == nil {\n", v)
+ fmt.Fprintf(&b, " w.Write(rlp.EmptyString)")
+ fmt.Fprintf(&b, "} else {\n")
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, "}\n")
+ }
+
+ return b.String()
+}
+
+func (op bigIntOp) genDecode(ctx *genContext) (string, string) {
+ var resultV = ctx.temp()
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV)
+ fmt.Fprintf(&b, "if err != nil { return err }\n")
+
+ result := resultV
+ if !op.pointer {
+ result = "(*" + resultV + ")"
+ }
+
+ return result, b.String()
+}
+
+// encoderDecoderOp handles rlp.Encoder and rlp.Decoder.
+// In order to be used with this, the type must implement both interfaces.
+// This restriction may be lifted in the future by creating separate ops for
+// encoding and decoding.
+type encoderDecoderOp struct {
+ typ types.Type
+}
+
+func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string {
+ return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v)
+}
+
+func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) {
+ var (
+ // DecodeRLP must have pointer receiver, and this is verified in makeOp.
+ //nolint:forcetypeassert
+ etyp = op.typ.(*types.Pointer).Elem()
+ resultV = ctx.temp()
+ b bytes.Buffer
+ )
+
+ fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify))
+ fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV)
+
+ return resultV, b.String()
+}
+
+// ptrOp handles pointer types.
+type ptrOp struct {
+ elemTyp types.Type
+ elem op
+ nilOK bool
+ nilValue rlpstruct.NilKind
+}
+
+func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) {
+ elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+
+ op := ptrOp{elemTyp: elemTyp, elem: elemOp}
+
+ // Determine nil value.
+ if tags.NilOK {
+ op.nilOK = true
+ op.nilValue = tags.NilKind
+ } else {
+ styp := bctx.typeToStructType(elemTyp)
+ op.nilValue = styp.DefaultNilValue()
+ }
+
+ return op, nil
+}
+
+func (op ptrOp) genWrite(ctx *genContext, v string) string {
+ // Note: in writer functions, accesses to v are read-only, i.e. v is any Go
+ // expression. To make all accesses work through the pointer, we substitute
+ // v with (*v). This is required for most accesses including `v`, `call(v)`,
+ // and `v[index]` on slices.
+ //
+ // For `v.field` and `v[:]` on arrays, the dereference operation is not required.
+ var vv string
+
+ _, isStruct := op.elem.(structOp)
+ _, isByteArray := op.elem.(byteArrayOp)
+
+ if isStruct || isByteArray {
+ vv = v
+ } else {
+ vv = fmt.Sprintf("(*%s)", v)
+ }
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "if %s == nil {\n", v)
+ fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue)
+ fmt.Fprintf(&b, "} else {\n")
+ fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv))
+ fmt.Fprintf(&b, "}\n")
+
+ return b.String()
+}
+
+func (op ptrOp) genDecode(ctx *genContext) (string, string) {
+ result, code := op.elem.genDecode(ctx)
+ if !op.nilOK {
+ // If nil pointers are not allowed, we can just decode the element.
+ return "&" + result, code
+ }
+
+ // nil is allowed, so check the kind and size first.
+ // If size is zero and kind matches the nilKind of the type,
+ // the value decodes as a nil pointer.
+ var (
+ resultV = ctx.temp()
+ kindV = ctx.temp()
+ sizeV = ctx.temp()
+ wantKind string
+ )
+
+ if op.nilValue == rlpstruct.NilKindList {
+ wantKind = "rlp.List"
+ } else {
+ wantKind = "rlp.String"
+ }
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify))
+ fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV)
+ fmt.Fprintf(&b, " return err\n")
+ fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, " %s = &%s\n", resultV, result)
+ fmt.Fprintf(&b, "}\n")
+
+ return resultV, b.String()
+}
+
+// structOp handles struct types.
+type structOp struct {
+ named *types.Named
+ typ *types.Struct
+ fields []*structField
+ optionalFields []*structField
+}
+
+type structField struct {
+ name string
+ typ types.Type
+ elem op
+}
+
+func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) {
+ // Convert fields to []rlpstruct.Field.
+ var allStructFields []rlpstruct.Field
+
+ for i := 0; i < typ.NumFields(); i++ {
+ f := typ.Field(i)
+
+ allStructFields = append(allStructFields, rlpstruct.Field{
+ Name: f.Name(),
+ Exported: f.Exported(),
+ Index: i,
+ Tag: typ.Tag(i),
+ Type: *bctx.typeToStructType(f.Type()),
+ })
+ }
+
+ // Filter/validate fields.
+ fields, tags, err := rlpstruct.ProcessFields(allStructFields)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create field ops.
+ var op = structOp{named: named, typ: typ}
+
+ for i, field := range fields {
+ // Advanced struct tags are not supported yet.
+ tag := tags[i]
+ if err := checkUnsupportedTags(field.Name, tag); err != nil {
+ return nil, err
+ }
+
+ typ := typ.Field(field.Index).Type()
+
+ elem, err := bctx.makeOp(nil, typ, tags[i])
+ if err != nil {
+ return nil, fmt.Errorf("field %s: %w", field.Name, err)
+ }
+
+ f := &structField{name: field.Name, typ: typ, elem: elem}
+ if tag.Optional {
+ op.optionalFields = append(op.optionalFields, f)
+ } else {
+ op.fields = append(op.fields, f)
+ }
+ }
+
+ return op, nil
+}
+
+func checkUnsupportedTags(field string, tag rlpstruct.Tags) error {
+ if tag.Tail {
+ return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field)
+ }
+
+ return nil
+}
+
+func (op structOp) genWrite(ctx *genContext, v string) string {
+ var (
+ b bytes.Buffer
+ listMarker = ctx.temp()
+ )
+
+ fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
+
+ for _, field := range op.fields {
+ selector := v + "." + field.name
+ fmt.Fprint(&b, field.elem.genWrite(ctx, selector))
+ }
+
+ op.writeOptionalFields(&b, ctx, v)
+ fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
+
+ return b.String()
+}
+
+func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) {
+ if len(op.optionalFields) == 0 {
+ return
+ }
+
+ // First check zero-ness of all optional fields.
+ var zeroV = make([]string, len(op.optionalFields))
+
+ for i, field := range op.optionalFields {
+ selector := v + "." + field.name
+ zeroV[i] = ctx.temp()
+ fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify))
+ }
+
+ // Now write the fields.
+ for i, field := range op.optionalFields {
+ selector := v + "." + field.name
+ cond := ""
+
+ for j := i; j < len(op.optionalFields); j++ {
+ if j > i {
+ cond += " || "
+ }
+
+ cond += zeroV[j]
+ }
+
+ fmt.Fprintf(b, "if %s {\n", cond)
+ fmt.Fprint(b, field.elem.genWrite(ctx, selector))
+ fmt.Fprintf(b, "}\n")
+ }
+}
+
+func (op structOp) genDecode(ctx *genContext) (string, string) {
+ // Get the string representation of the type.
+ // Here, named types are handled separately because the output
+ // would contain a copy of the struct definition otherwise.
+ var typeName string
+ if op.named != nil {
+ typeName = types.TypeString(op.named, ctx.qualify)
+ } else {
+ typeName = types.TypeString(op.typ, ctx.qualify)
+ }
+
+ // Create struct object.
+ var (
+ resultV = ctx.temp()
+ b bytes.Buffer
+ )
+
+ fmt.Fprintf(&b, "var %s %s\n", resultV, typeName)
+
+ // Decode fields.
+ fmt.Fprintf(&b, "{\n")
+ fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
+
+ for _, field := range op.fields {
+ result, code := field.elem.genDecode(ctx)
+ fmt.Fprintf(&b, "// %s:\n", field.name)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result)
+ }
+
+ op.decodeOptionalFields(&b, ctx, resultV)
+ fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
+ fmt.Fprintf(&b, "}\n")
+
+ return resultV, b.String()
+}
+
+func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) {
+ var suffix bytes.Buffer
+
+ for _, field := range op.optionalFields {
+ result, code := field.elem.genDecode(ctx)
+ fmt.Fprintf(b, "// %s:\n", field.name)
+ fmt.Fprintf(b, "if dec.MoreDataInList() {\n")
+ fmt.Fprint(b, code)
+ fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result)
+ fmt.Fprintf(&suffix, "}\n")
+ }
+
+ suffix.WriteTo(b)
+}
+
+// sliceOp handles slice types.
+type sliceOp struct {
+ typ *types.Slice
+ elemOp op
+}
+
+func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) {
+ elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+
+ return sliceOp{typ: typ, elemOp: elemOp}, nil
+}
+
+func (op sliceOp) genWrite(ctx *genContext, v string) string {
+ var (
+ listMarker = ctx.temp() // holds return value of w.List()
+ iterElemV = ctx.temp() // iteration variable
+ elemCode = op.elemOp.genWrite(ctx, iterElemV)
+ b bytes.Buffer
+ )
+
+ fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
+ fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v)
+ fmt.Fprint(&b, elemCode)
+ fmt.Fprintf(&b, "}\n")
+ fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
+
+ return b.String()
+}
+
+func (op sliceOp) genDecode(ctx *genContext) (string, string) {
+ var (
+ sliceV = ctx.temp() // holds the output slice
+ elemResult, elemCode = op.elemOp.genDecode(ctx)
+ b bytes.Buffer
+ )
+
+ fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify))
+ fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
+ fmt.Fprintf(&b, "for dec.MoreDataInList() {\n")
+ fmt.Fprintf(&b, " %s", elemCode)
+ fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult)
+ fmt.Fprintf(&b, "}\n")
+ fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
+
+ return sliceV, b.String()
+}
+
+func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) {
+ switch typ := typ.(type) {
+ case *types.Named:
+ if isBigInt(typ) {
+ return bigIntOp{}, nil
+ }
+
+ if typ == bctx.rawValueType {
+ return bctx.makeRawValueOp(), nil
+ }
+
+ if bctx.isDecoder(typ) {
+ return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ)
+ }
+ // TODO: same check for encoder?
+ return bctx.makeOp(typ, typ.Underlying(), tags)
+ case *types.Pointer:
+ if isBigInt(typ.Elem()) {
+ return bigIntOp{pointer: true}, nil
+ }
+ // Encoder/Decoder interfaces.
+ if bctx.isEncoder(typ) {
+ if bctx.isDecoder(typ) {
+ return encoderDecoderOp{typ}, nil
+ }
+
+ return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ)
+ }
+
+ if bctx.isDecoder(typ) {
+ return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ)
+ }
+ // Default pointer handling.
+ return bctx.makePtrOp(typ.Elem(), tags)
+ case *types.Basic:
+ return bctx.makeBasicOp(typ)
+ case *types.Struct:
+ return bctx.makeStructOp(name, typ)
+ case *types.Slice:
+ etyp := typ.Elem()
+ if isByte(etyp) && !bctx.isEncoder(etyp) {
+ return bctx.makeByteSliceOp(typ), nil
+ }
+
+ return bctx.makeSliceOp(typ)
+ case *types.Array:
+ etyp := typ.Elem()
+
+ if isByte(etyp) && !bctx.isEncoder(etyp) {
+ return bctx.makeByteArrayOp(name, typ), nil
+ }
+
+ return nil, fmt.Errorf("unhandled array type: %v", typ)
+ default:
+ return nil, fmt.Errorf("unhandled type: %v", typ)
+ }
+}
+
+// generateDecoder generates the DecodeRLP method on 'typ'.
+func generateDecoder(ctx *genContext, typ string, op op) []byte {
+ ctx.resetTemp()
+ ctx.addImport(pathOfPackageRLP)
+
+ result, code := op.genDecode(ctx)
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ)
+ fmt.Fprint(&b, code)
+ fmt.Fprintf(&b, " *obj = %s\n", result)
+ fmt.Fprintf(&b, " return nil\n")
+ fmt.Fprintf(&b, "}\n")
+
+ return b.Bytes()
+}
+
+// generateEncoder generates the EncodeRLP method on 'typ'.
+func generateEncoder(ctx *genContext, typ string, op op) []byte {
+ ctx.resetTemp()
+ ctx.addImport("io")
+ ctx.addImport(pathOfPackageRLP)
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
+ fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n")
+ fmt.Fprint(&b, op.genWrite(ctx, "obj"))
+ fmt.Fprintf(&b, " return w.Flush()\n")
+ fmt.Fprintf(&b, "}\n")
+
+ return b.Bytes()
+}
+
+func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) {
+ bctx.topType = typ
+
+ pkg := typ.Obj().Pkg()
+
+ op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{})
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ ctx = newGenContext(pkg)
+ encSource []byte
+ decSource []byte
+ )
+
+ if encoder {
+ encSource = generateEncoder(ctx, typ.Obj().Name(), op)
+ }
+
+ if decoder {
+ decSource = generateDecoder(ctx, typ.Obj().Name(), op)
+ }
+
+ var b bytes.Buffer
+
+ fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
+
+ for _, imp := range ctx.importsList() {
+ fmt.Fprintf(&b, "import %q\n", imp)
+ }
+
+ if encoder {
+ fmt.Fprintln(&b)
+ b.Write(encSource)
+ }
+
+ if decoder {
+ fmt.Fprintln(&b)
+ b.Write(decSource)
+ }
+
+ source := b.Bytes()
+ // fmt.Println(string(source))
+ return format.Source(source)
+}
diff --git a/helper/rlp/rlpgen/gen_test.go b/helper/rlp/rlpgen/gen_test.go
new file mode 100644
index 0000000000..90f65e9e80
--- /dev/null
+++ b/helper/rlp/rlpgen/gen_test.go
@@ -0,0 +1,113 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/importer"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+// Package RLP is loaded only once and reused for all tests.
+var (
+ testFset = token.NewFileSet()
+ testImporter, _ = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom)
+ testPackageRLP *types.Package
+)
+
+func init() {
+ cwd, err := os.Getwd()
+ if err != nil {
+ panic(err)
+ }
+
+ testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0)
+ if err != nil {
+ panic(fmt.Errorf("can't load package RLP: %w", err))
+ }
+}
+
+var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint"}
+
+func TestOutput(t *testing.T) {
+ for _, test := range tests {
+ test := test
+ t.Run(test, func(t *testing.T) {
+ inputFile := filepath.Join("testdata", test+".in.txt")
+ outputFile := filepath.Join("testdata", test+".out.txt")
+ bctx, typ, err := loadTestSource(inputFile, "Test")
+ if err != nil {
+ t.Fatal("error loading test source:", err)
+ }
+ output, err := bctx.generate(typ, true, true)
+ if err != nil {
+ t.Fatal("error in generate:", err)
+ }
+
+ // Set this environment variable to regenerate the test outputs.
+ if os.Getenv("WRITE_TEST_FILES") != "" {
+ os.WriteFile(outputFile, output, 0644)
+ }
+
+ // Check if output matches.
+ wantOutput, err := os.ReadFile(outputFile)
+ if err != nil {
+ t.Fatal("error loading expected test output:", err)
+ }
+ if !bytes.Equal(output, wantOutput) {
+ t.Fatal("output mismatch:\n", string(output))
+ }
+ })
+ }
+}
+
+func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) {
+ // Load the test input.
+ content, err := os.ReadFile(file)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ f, err := parser.ParseFile(testFset, file, content, 0)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ conf := types.Config{Importer: testImporter}
+
+ pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Find the test struct.
+ bctx := newBuildContext(testPackageRLP)
+
+ typ, err := lookupStructType(pkg.Scope(), typeName)
+ if err != nil {
+ return nil, nil, fmt.Errorf("can't find type %s: %w", typeName, err)
+ }
+
+ return bctx, typ, nil
+}
diff --git a/helper/rlp/rlpgen/main.go b/helper/rlp/rlpgen/main.go
new file mode 100644
index 0000000000..7ef992c659
--- /dev/null
+++ b/helper/rlp/rlpgen/main.go
@@ -0,0 +1,164 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package main
+
+import (
+ "bytes"
+ "errors"
+ "flag"
+ "fmt"
+ "go/types"
+ "os"
+
+ "golang.org/x/tools/go/packages"
+)
+
+const pathOfPackageRLP = "github.com/dogechain-lab/dogechain/helper/rlp"
+
+func main() {
+ var (
+ pkgdir = flag.String("dir", ".", "input package")
+ output = flag.String("out", "-", "output file (default is stdout)")
+ genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?")
+ genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?")
+ typename = flag.String("type", "", "type to generate methods for")
+ )
+
+ flag.Parse()
+
+ cfg := Config{
+ Dir: *pkgdir,
+ Type: *typename,
+ GenerateEncoder: *genEncoder,
+ GenerateDecoder: *genDecoder,
+ }
+
+ code, err := cfg.process()
+ if err != nil {
+ fatal(err)
+ }
+
+ if *output == "-" {
+ os.Stdout.Write(code)
+ } else if err := os.WriteFile(*output, code, 0600); err != nil {
+ fatal(err)
+ }
+}
+
+func fatal(args ...interface{}) {
+ fmt.Fprintln(os.Stderr, args...)
+ os.Exit(1)
+}
+
+type Config struct {
+ Dir string // input package directory
+ Type string
+
+ GenerateEncoder bool
+ GenerateDecoder bool
+}
+
+// process generates the Go code.
+func (cfg *Config) process() (code []byte, err error) {
+ // Load packages.
+ pcfg := &packages.Config{
+ Mode: packages.NeedName | packages.NeedTypes | packages.NeedImports | packages.NeedDeps,
+ Dir: cfg.Dir,
+ BuildFlags: []string{"-tags", "norlpgen"},
+ }
+
+ ps, err := packages.Load(pcfg, pathOfPackageRLP, ".")
+ if err != nil {
+ return nil, err
+ }
+
+ if len(ps) == 0 {
+ return nil, fmt.Errorf("no Go package found in %s", cfg.Dir)
+ }
+
+ packages.PrintErrors(ps)
+
+ // Find the packages that were loaded.
+ var (
+ pkg *types.Package
+ packageRLP *types.Package
+ )
+
+ for _, p := range ps {
+ if len(p.Errors) > 0 {
+ return nil, fmt.Errorf("package %s has errors", p.PkgPath)
+ }
+
+ if p.PkgPath == pathOfPackageRLP {
+ packageRLP = p.Types
+ } else {
+ pkg = p.Types
+ }
+ }
+
+ bctx := newBuildContext(packageRLP)
+
+ // Find the type and generate.
+ typ, err := lookupStructType(pkg.Scope(), cfg.Type)
+ if err != nil {
+ return nil, fmt.Errorf("can't find %s in %s: %w", cfg.Type, pkg, err)
+ }
+
+ code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder)
+ if err != nil {
+ return nil, err
+ }
+
+ // Add build comments.
+ // This is done here to avoid processing these lines with gofmt.
+ var header bytes.Buffer
+
+ fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n")
+ fmt.Fprint(&header, "//go:build !norlpgen\n")
+ fmt.Fprint(&header, "// +build !norlpgen\n\n")
+
+ return append(header.Bytes(), code...), nil
+}
+
+func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
+ typ, err := lookupType(scope, name)
+ if err != nil {
+ return nil, err
+ }
+
+ _, ok := typ.Underlying().(*types.Struct)
+ if !ok {
+ return nil, errors.New("not a struct type")
+ }
+
+ return typ, nil
+}
+
+func lookupType(scope *types.Scope, name string) (*types.Named, error) {
+ obj := scope.Lookup(name)
+ if obj == nil {
+ return nil, errors.New("no such identifier")
+ }
+
+ typ, ok := obj.(*types.TypeName)
+ if !ok {
+ return nil, errors.New("not a type")
+ }
+
+ //nolint:forcetypeassert
+ return typ.Type().(*types.Named), nil
+}
diff --git a/helper/rlp/rlpgen/testdata/bigint.in.txt b/helper/rlp/rlpgen/testdata/bigint.in.txt
new file mode 100644
index 0000000000..d23d84a287
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/bigint.in.txt
@@ -0,0 +1,10 @@
+// -*- mode: go -*-
+
+package test
+
+import "math/big"
+
+type Test struct {
+ Int *big.Int
+ IntNoPtr big.Int
+}
diff --git a/helper/rlp/rlpgen/testdata/bigint.out.txt b/helper/rlp/rlpgen/testdata/bigint.out.txt
new file mode 100644
index 0000000000..226b005f8b
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/bigint.out.txt
@@ -0,0 +1,49 @@
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ if obj.Int == nil {
+ w.Write(rlp.EmptyString)
+ } else {
+ if obj.Int.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(obj.Int)
+ }
+ if obj.IntNoPtr.Sign() == -1 {
+ return rlp.ErrNegativeBigInt
+ }
+ w.WriteBigInt(&obj.IntNoPtr)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Int:
+ _tmp1, err := dec.BigInt()
+ if err != nil {
+ return err
+ }
+ _tmp0.Int = _tmp1
+ // IntNoPtr:
+ _tmp2, err := dec.BigInt()
+ if err != nil {
+ return err
+ }
+ _tmp0.IntNoPtr = (*_tmp2)
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/helper/rlp/rlpgen/testdata/nil.in.txt b/helper/rlp/rlpgen/testdata/nil.in.txt
new file mode 100644
index 0000000000..a28ff34487
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/nil.in.txt
@@ -0,0 +1,30 @@
+// -*- mode: go -*-
+
+package test
+
+type Aux struct{
+ A uint32
+}
+
+type Test struct{
+ Uint8 *byte `rlp:"nil"`
+ Uint8List *byte `rlp:"nilList"`
+
+ Uint32 *uint32 `rlp:"nil"`
+ Uint32List *uint32 `rlp:"nilList"`
+
+ Uint64 *uint64 `rlp:"nil"`
+ Uint64List *uint64 `rlp:"nilList"`
+
+ String *string `rlp:"nil"`
+ StringList *string `rlp:"nilList"`
+
+ ByteArray *[3]byte `rlp:"nil"`
+ ByteArrayList *[3]byte `rlp:"nilList"`
+
+ ByteSlice *[]byte `rlp:"nil"`
+ ByteSliceList *[]byte `rlp:"nilList"`
+
+ Struct *Aux `rlp:"nil"`
+ StructString *Aux `rlp:"nilString"`
+}
diff --git a/helper/rlp/rlpgen/testdata/nil.out.txt b/helper/rlp/rlpgen/testdata/nil.out.txt
new file mode 100644
index 0000000000..06c2147af6
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/nil.out.txt
@@ -0,0 +1,289 @@
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ if obj.Uint8 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint8)))
+ }
+ if obj.Uint8List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint8List)))
+ }
+ if obj.Uint32 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint32)))
+ }
+ if obj.Uint32List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64(uint64((*obj.Uint32List)))
+ }
+ if obj.Uint64 == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64((*obj.Uint64))
+ }
+ if obj.Uint64List == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteUint64((*obj.Uint64List))
+ }
+ if obj.String == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteString((*obj.String))
+ }
+ if obj.StringList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteString((*obj.StringList))
+ }
+ if obj.ByteArray == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteBytes(obj.ByteArray[:])
+ }
+ if obj.ByteArrayList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteBytes(obj.ByteArrayList[:])
+ }
+ if obj.ByteSlice == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteBytes((*obj.ByteSlice))
+ }
+ if obj.ByteSliceList == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ w.WriteBytes((*obj.ByteSliceList))
+ }
+ if obj.Struct == nil {
+ w.Write([]byte{0xC0})
+ } else {
+ _tmp1 := w.List()
+ w.WriteUint64(uint64(obj.Struct.A))
+ w.ListEnd(_tmp1)
+ }
+ if obj.StructString == nil {
+ w.Write([]byte{0x80})
+ } else {
+ _tmp2 := w.List()
+ w.WriteUint64(uint64(obj.StructString.A))
+ w.ListEnd(_tmp2)
+ }
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Uint8:
+ var _tmp2 *byte
+ if _tmp3, _tmp4, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp4 != 0 || _tmp3 != rlp.String {
+ _tmp1, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp2 = &_tmp1
+ }
+ _tmp0.Uint8 = _tmp2
+ // Uint8List:
+ var _tmp6 *byte
+ if _tmp7, _tmp8, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp8 != 0 || _tmp7 != rlp.List {
+ _tmp5, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp6 = &_tmp5
+ }
+ _tmp0.Uint8List = _tmp6
+ // Uint32:
+ var _tmp10 *uint32
+ if _tmp11, _tmp12, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp12 != 0 || _tmp11 != rlp.String {
+ _tmp9, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp10 = &_tmp9
+ }
+ _tmp0.Uint32 = _tmp10
+ // Uint32List:
+ var _tmp14 *uint32
+ if _tmp15, _tmp16, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp16 != 0 || _tmp15 != rlp.List {
+ _tmp13, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp14 = &_tmp13
+ }
+ _tmp0.Uint32List = _tmp14
+ // Uint64:
+ var _tmp18 *uint64
+ if _tmp19, _tmp20, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp20 != 0 || _tmp19 != rlp.String {
+ _tmp17, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp18 = &_tmp17
+ }
+ _tmp0.Uint64 = _tmp18
+ // Uint64List:
+ var _tmp22 *uint64
+ if _tmp23, _tmp24, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp24 != 0 || _tmp23 != rlp.List {
+ _tmp21, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp22 = &_tmp21
+ }
+ _tmp0.Uint64List = _tmp22
+ // String:
+ var _tmp26 *string
+ if _tmp27, _tmp28, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp28 != 0 || _tmp27 != rlp.String {
+ _tmp25, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp26 = &_tmp25
+ }
+ _tmp0.String = _tmp26
+ // StringList:
+ var _tmp30 *string
+ if _tmp31, _tmp32, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp32 != 0 || _tmp31 != rlp.List {
+ _tmp29, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp30 = &_tmp29
+ }
+ _tmp0.StringList = _tmp30
+ // ByteArray:
+ var _tmp34 *[3]byte
+ if _tmp35, _tmp36, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp36 != 0 || _tmp35 != rlp.String {
+ var _tmp33 [3]byte
+ if err := dec.ReadBytes(_tmp33[:]); err != nil {
+ return err
+ }
+ _tmp34 = &_tmp33
+ }
+ _tmp0.ByteArray = _tmp34
+ // ByteArrayList:
+ var _tmp38 *[3]byte
+ if _tmp39, _tmp40, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp40 != 0 || _tmp39 != rlp.List {
+ var _tmp37 [3]byte
+ if err := dec.ReadBytes(_tmp37[:]); err != nil {
+ return err
+ }
+ _tmp38 = &_tmp37
+ }
+ _tmp0.ByteArrayList = _tmp38
+ // ByteSlice:
+ var _tmp42 *[]byte
+ if _tmp43, _tmp44, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp44 != 0 || _tmp43 != rlp.String {
+ _tmp41, err := dec.Bytes()
+ if err != nil {
+ return err
+ }
+ _tmp42 = &_tmp41
+ }
+ _tmp0.ByteSlice = _tmp42
+ // ByteSliceList:
+ var _tmp46 *[]byte
+ if _tmp47, _tmp48, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp48 != 0 || _tmp47 != rlp.List {
+ _tmp45, err := dec.Bytes()
+ if err != nil {
+ return err
+ }
+ _tmp46 = &_tmp45
+ }
+ _tmp0.ByteSliceList = _tmp46
+ // Struct:
+ var _tmp51 *Aux
+ if _tmp52, _tmp53, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp53 != 0 || _tmp52 != rlp.List {
+ var _tmp49 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp50, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp49.A = _tmp50
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp51 = &_tmp49
+ }
+ _tmp0.Struct = _tmp51
+ // StructString:
+ var _tmp56 *Aux
+ if _tmp57, _tmp58, err := dec.Kind(); err != nil {
+ return err
+ } else if _tmp58 != 0 || _tmp57 != rlp.String {
+ var _tmp54 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp55, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp54.A = _tmp55
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp56 = &_tmp54
+ }
+ _tmp0.StructString = _tmp56
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/helper/rlp/rlpgen/testdata/optional.in.txt b/helper/rlp/rlpgen/testdata/optional.in.txt
new file mode 100644
index 0000000000..f1ac9f7899
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/optional.in.txt
@@ -0,0 +1,17 @@
+// -*- mode: go -*-
+
+package test
+
+type Aux struct {
+ A uint64
+}
+
+type Test struct {
+ Uint64 uint64 `rlp:"optional"`
+ Pointer *uint64 `rlp:"optional"`
+ String string `rlp:"optional"`
+ Slice []uint64 `rlp:"optional"`
+ Array [3]byte `rlp:"optional"`
+ NamedStruct Aux `rlp:"optional"`
+ AnonStruct struct{ A string } `rlp:"optional"`
+}
diff --git a/helper/rlp/rlpgen/testdata/optional.out.txt b/helper/rlp/rlpgen/testdata/optional.out.txt
new file mode 100644
index 0000000000..619640f5bd
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/optional.out.txt
@@ -0,0 +1,153 @@
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ _tmp1 := obj.Uint64 != 0
+ _tmp2 := obj.Pointer != nil
+ _tmp3 := obj.String != ""
+ _tmp4 := len(obj.Slice) > 0
+ _tmp5 := obj.Array != ([3]byte{})
+ _tmp6 := obj.NamedStruct != (Aux{})
+ _tmp7 := obj.AnonStruct != (struct{ A string }{})
+ if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ w.WriteUint64(obj.Uint64)
+ }
+ if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ if obj.Pointer == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.WriteUint64((*obj.Pointer))
+ }
+ }
+ if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ w.WriteString(obj.String)
+ }
+ if _tmp4 || _tmp5 || _tmp6 || _tmp7 {
+ _tmp8 := w.List()
+ for _, _tmp9 := range obj.Slice {
+ w.WriteUint64(_tmp9)
+ }
+ w.ListEnd(_tmp8)
+ }
+ if _tmp5 || _tmp6 || _tmp7 {
+ w.WriteBytes(obj.Array[:])
+ }
+ if _tmp6 || _tmp7 {
+ _tmp10 := w.List()
+ w.WriteUint64(obj.NamedStruct.A)
+ w.ListEnd(_tmp10)
+ }
+ if _tmp7 {
+ _tmp11 := w.List()
+ w.WriteString(obj.AnonStruct.A)
+ w.ListEnd(_tmp11)
+ }
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // Uint64:
+ if dec.MoreDataInList() {
+ _tmp1, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.Uint64 = _tmp1
+ // Pointer:
+ if dec.MoreDataInList() {
+ _tmp2, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.Pointer = &_tmp2
+ // String:
+ if dec.MoreDataInList() {
+ _tmp3, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp0.String = _tmp3
+ // Slice:
+ if dec.MoreDataInList() {
+ var _tmp4 []uint64
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ for dec.MoreDataInList() {
+ _tmp5, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp4 = append(_tmp4, _tmp5)
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ _tmp0.Slice = _tmp4
+ // Array:
+ if dec.MoreDataInList() {
+ var _tmp6 [3]byte
+ if err := dec.ReadBytes(_tmp6[:]); err != nil {
+ return err
+ }
+ _tmp0.Array = _tmp6
+ // NamedStruct:
+ if dec.MoreDataInList() {
+ var _tmp7 Aux
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp8, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp7.A = _tmp8
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp0.NamedStruct = _tmp7
+ // AnonStruct:
+ if dec.MoreDataInList() {
+ var _tmp9 struct{ A string }
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp10, err := dec.String()
+ if err != nil {
+ return err
+ }
+ _tmp9.A = _tmp10
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ _tmp0.AnonStruct = _tmp9
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/helper/rlp/rlpgen/testdata/rawvalue.in.txt b/helper/rlp/rlpgen/testdata/rawvalue.in.txt
new file mode 100644
index 0000000000..ee7c5aea26
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/rawvalue.in.txt
@@ -0,0 +1,11 @@
+// -*- mode: go -*-
+
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+
+type Test struct {
+ RawValue rlp.RawValue
+ PointerToRawValue *rlp.RawValue
+ SliceOfRawValue []rlp.RawValue
+}
diff --git a/helper/rlp/rlpgen/testdata/rawvalue.out.txt b/helper/rlp/rlpgen/testdata/rawvalue.out.txt
new file mode 100644
index 0000000000..07b6166924
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/rawvalue.out.txt
@@ -0,0 +1,64 @@
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ w.Write(obj.RawValue)
+ if obj.PointerToRawValue == nil {
+ w.Write([]byte{0x80})
+ } else {
+ w.Write((*obj.PointerToRawValue))
+ }
+ _tmp1 := w.List()
+ for _, _tmp2 := range obj.SliceOfRawValue {
+ w.Write(_tmp2)
+ }
+ w.ListEnd(_tmp1)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // RawValue:
+ _tmp1, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp0.RawValue = _tmp1
+ // PointerToRawValue:
+ _tmp2, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp0.PointerToRawValue = &_tmp2
+ // SliceOfRawValue:
+ var _tmp3 []rlp.RawValue
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ for dec.MoreDataInList() {
+ _tmp4, err := dec.Raw()
+ if err != nil {
+ return err
+ }
+ _tmp3 = append(_tmp3, _tmp4)
+ }
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ _tmp0.SliceOfRawValue = _tmp3
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/helper/rlp/rlpgen/testdata/uints.in.txt b/helper/rlp/rlpgen/testdata/uints.in.txt
new file mode 100644
index 0000000000..8095da997d
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/uints.in.txt
@@ -0,0 +1,10 @@
+// -*- mode: go -*-
+
+package test
+
+type Test struct{
+ A uint8
+ B uint16
+ C uint32
+ D uint64
+}
diff --git a/helper/rlp/rlpgen/testdata/uints.out.txt b/helper/rlp/rlpgen/testdata/uints.out.txt
new file mode 100644
index 0000000000..ed953fd706
--- /dev/null
+++ b/helper/rlp/rlpgen/testdata/uints.out.txt
@@ -0,0 +1,53 @@
+package test
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+import "io"
+
+func (obj *Test) EncodeRLP(_w io.Writer) error {
+ w := rlp.NewEncoderBuffer(_w)
+ _tmp0 := w.List()
+ w.WriteUint64(uint64(obj.A))
+ w.WriteUint64(uint64(obj.B))
+ w.WriteUint64(uint64(obj.C))
+ w.WriteUint64(obj.D)
+ w.ListEnd(_tmp0)
+ return w.Flush()
+}
+
+func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
+ var _tmp0 Test
+ {
+ if _, err := dec.List(); err != nil {
+ return err
+ }
+ // A:
+ _tmp1, err := dec.Uint8()
+ if err != nil {
+ return err
+ }
+ _tmp0.A = _tmp1
+ // B:
+ _tmp2, err := dec.Uint16()
+ if err != nil {
+ return err
+ }
+ _tmp0.B = _tmp2
+ // C:
+ _tmp3, err := dec.Uint32()
+ if err != nil {
+ return err
+ }
+ _tmp0.C = _tmp3
+ // D:
+ _tmp4, err := dec.Uint64()
+ if err != nil {
+ return err
+ }
+ _tmp0.D = _tmp4
+ if err := dec.ListEnd(); err != nil {
+ return err
+ }
+ }
+ *obj = _tmp0
+ return nil
+}
diff --git a/helper/rlp/rlpgen/types.go b/helper/rlp/rlpgen/types.go
new file mode 100644
index 0000000000..edeaad432a
--- /dev/null
+++ b/helper/rlp/rlpgen/types.go
@@ -0,0 +1,120 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package main
+
+import (
+ "fmt"
+ "go/types"
+ "reflect"
+)
+
+// typeReflectKind gives the reflect.Kind that represents typ.
+func typeReflectKind(typ types.Type) reflect.Kind {
+ switch typ := typ.(type) {
+ case *types.Basic:
+ k := typ.Kind()
+
+ switch {
+ case k >= types.Bool && k <= types.Complex128:
+ // value order matches for Bool..Complex128
+ return reflect.Bool + reflect.Kind(k-types.Bool)
+ case k == types.String:
+ return reflect.String
+ case k == types.UnsafePointer:
+ return reflect.UnsafePointer
+ }
+
+ panic(fmt.Errorf("unhandled BasicKind %v", k))
+ case *types.Array:
+ return reflect.Array
+ case *types.Chan:
+ return reflect.Chan
+ case *types.Interface:
+ return reflect.Interface
+ case *types.Map:
+ return reflect.Map
+ case *types.Pointer:
+ return reflect.Ptr
+ case *types.Signature:
+ return reflect.Func
+ case *types.Slice:
+ return reflect.Slice
+ case *types.Struct:
+ return reflect.Struct
+ default:
+ panic(fmt.Errorf("unhandled type %T", typ))
+ }
+}
+
+// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'.
+func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string {
+ // Resolve type name.
+ typ := resolveUnderlying(vtyp)
+ switch typ := typ.(type) {
+ case *types.Basic:
+ k := typ.Kind()
+
+ switch {
+ case k == types.Bool:
+ return v
+ case k >= types.Uint && k <= types.Complex128:
+ return fmt.Sprintf("%s != 0", v)
+ case k == types.String:
+ return fmt.Sprintf(`%s != ""`, v)
+ default:
+ panic(fmt.Errorf("unhandled BasicKind %v", k))
+ }
+ case *types.Array, *types.Struct:
+ return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify))
+ case *types.Interface, *types.Pointer, *types.Signature:
+ return fmt.Sprintf("%s != nil", v)
+ case *types.Slice, *types.Map:
+ return fmt.Sprintf("len(%s) > 0", v)
+ default:
+ panic(fmt.Errorf("unhandled type %T", typ))
+ }
+}
+
+// isBigInt checks whether 'typ' is "math/big".Int.
+func isBigInt(typ types.Type) bool {
+ named, ok := typ.(*types.Named)
+ if !ok {
+ return false
+ }
+
+ name := named.Obj()
+
+ return name.Pkg().Path() == "math/big" && name.Name() == "Int"
+}
+
+// isByte checks whether the underlying type of 'typ' is uint8.
+func isByte(typ types.Type) bool {
+ basic, ok := resolveUnderlying(typ).(*types.Basic)
+
+ return ok && basic.Kind() == types.Uint8
+}
+
+func resolveUnderlying(typ types.Type) types.Type {
+ for {
+ t := typ.Underlying()
+ if t == typ {
+ return t
+ }
+
+ typ = t
+ }
+}
diff --git a/helper/rlp/safe.go b/helper/rlp/safe.go
new file mode 100644
index 0000000000..3c910337b6
--- /dev/null
+++ b/helper/rlp/safe.go
@@ -0,0 +1,27 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+//go:build nacl || js || !cgo
+// +build nacl js !cgo
+
+package rlp
+
+import "reflect"
+
+// byteArrayBytes returns a slice of the byte array v.
+func byteArrayBytes(v reflect.Value, length int) []byte {
+ return v.Slice(0, length).Bytes()
+}
diff --git a/helper/rlp/typecache.go b/helper/rlp/typecache.go
new file mode 100644
index 0000000000..f1a9e67a96
--- /dev/null
+++ b/helper/rlp/typecache.go
@@ -0,0 +1,258 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package rlp
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "sync"
+ "sync/atomic"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp/internal/rlpstruct"
+)
+
+// typeinfo is an entry in the type cache.
+type typeinfo struct {
+ decoder decoder
+ decoderErr error // error from makeDecoder
+ writer writer
+ writerErr error // error from makeWriter
+}
+
+// typekey is the key of a type in typeCache. It includes the struct tags because
+// they might generate a different decoder.
+type typekey struct {
+ reflect.Type
+ rlpstruct.Tags
+}
+
+type decoder func(*Stream, reflect.Value) error
+
+type writer func(reflect.Value, *encBuffer) error
+
+var theTC = newTypeCache()
+
+type typeCache struct {
+ cur atomic.Value
+
+ // This lock synchronizes writers.
+ mu sync.Mutex
+ next map[typekey]*typeinfo
+}
+
+func newTypeCache() *typeCache {
+ c := new(typeCache)
+ c.cur.Store(make(map[typekey]*typeinfo))
+
+ return c
+}
+
+func cachedDecoder(typ reflect.Type) (decoder, error) {
+ info := theTC.info(typ)
+
+ return info.decoder, info.decoderErr
+}
+
+func cachedWriter(typ reflect.Type) (writer, error) {
+ info := theTC.info(typ)
+
+ return info.writer, info.writerErr
+}
+
+func (c *typeCache) info(typ reflect.Type) *typeinfo {
+ key := typekey{Type: typ}
+ //nolint:forcetypeassert
+ if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil {
+ return info
+ }
+
+ // Not in the cache, need to generate info for this type.
+ return c.generate(typ, rlpstruct.Tags{})
+}
+
+func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ cur, _ := c.cur.Load().(map[typekey]*typeinfo)
+ if info := cur[typekey{typ, tags}]; info != nil {
+ return info
+ }
+
+ // Copy cur to next.
+ c.next = make(map[typekey]*typeinfo, len(cur)+1)
+ for k, v := range cur {
+ c.next[k] = v
+ }
+
+ // Generate.
+ info := c.infoWhileGenerating(typ, tags)
+
+ // next -> cur
+ c.cur.Store(c.next)
+ c.next = nil
+
+ return info
+}
+
+func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo {
+ key := typekey{typ, tags}
+ if info := c.next[key]; info != nil {
+ return info
+ }
+ // Put a dummy value into the cache before generating.
+ // If the generator tries to lookup itself, it will get
+ // the dummy value and won't call itself recursively.
+ info := new(typeinfo)
+ c.next[key] = info
+ info.generate(typ, tags)
+
+ return info
+}
+
+type field struct {
+ index int
+ info *typeinfo
+ optional bool
+}
+
+// structFields resolves the typeinfo of all public fields in a struct type.
+func structFields(typ reflect.Type) (fields []field, err error) {
+ // Convert fields to rlpstruct.Field.
+ var allStructFields = make([]rlpstruct.Field, 0, typ.NumField())
+
+ for i := 0; i < typ.NumField(); i++ {
+ rf := typ.Field(i)
+
+ allStructFields = append(allStructFields, rlpstruct.Field{
+ Name: rf.Name,
+ Index: i,
+ Exported: rf.PkgPath == "",
+ Tag: string(rf.Tag),
+ Type: *rtypeToStructType(rf.Type, nil),
+ })
+ }
+
+ // Filter/validate fields.
+ structFields, structTags, err := rlpstruct.ProcessFields(allStructFields)
+ if err != nil {
+ var tagErr rlpstruct.TagError
+ if errors.As(err, &tagErr) {
+ tagErr.StructType = typ.String()
+
+ return nil, &tagErr
+ }
+
+ return nil, err
+ }
+
+ // Resolve typeinfo.
+ for i, sf := range structFields {
+ typ := typ.Field(sf.Index).Type
+ tags := structTags[i]
+ info := theTC.infoWhileGenerating(typ, tags)
+ fields = append(fields, field{sf.Index, info, tags.Optional})
+ }
+
+ return fields, nil
+}
+
+// firstOptionalField returns the index of the first field with "optional" tag.
+func firstOptionalField(fields []field) int {
+ for i, f := range fields {
+ if f.optional {
+ return i
+ }
+ }
+
+ return len(fields)
+}
+
+type structFieldError struct {
+ typ reflect.Type
+ field int
+ err error
+}
+
+func (e structFieldError) Error() string {
+ return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name)
+}
+
+func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) {
+ i.decoder, i.decoderErr = makeDecoder(typ, tags)
+ i.writer, i.writerErr = makeWriter(typ, tags)
+}
+
+// rtypeToStructType converts typ to rlpstruct.Type.
+func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type {
+ k := typ.Kind()
+ if k == reflect.Invalid {
+ panic("invalid kind")
+ }
+
+ if prev := rec[typ]; prev != nil {
+ return prev // short-circuit for recursive types
+ }
+
+ if rec == nil {
+ rec = make(map[reflect.Type]*rlpstruct.Type)
+ }
+
+ t := &rlpstruct.Type{
+ Name: typ.String(),
+ Kind: k,
+ IsEncoder: typ.Implements(encoderInterface),
+ IsDecoder: typ.Implements(decoderInterface),
+ }
+
+ rec[typ] = t
+ if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr {
+ t.Elem = rtypeToStructType(typ.Elem(), rec)
+ }
+
+ return t
+}
+
+// typeNilKind gives the RLP value kind for nil pointers to 'typ'.
+func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind {
+ styp := rtypeToStructType(typ, nil)
+
+ var nk rlpstruct.NilKind
+ if tags.NilOK {
+ nk = tags.NilKind
+ } else {
+ nk = styp.DefaultNilValue()
+ }
+
+ switch nk {
+ case rlpstruct.NilKindString:
+ return String
+ case rlpstruct.NilKindList:
+ return List
+ default:
+ panic("invalid nil kind value")
+ }
+}
+
+func isUint(k reflect.Kind) bool {
+ return k >= reflect.Uint && k <= reflect.Uintptr
+}
+
+func isByte(typ reflect.Type) bool {
+ return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
+}
diff --git a/helper/rlp/unsafe.go b/helper/rlp/unsafe.go
new file mode 100644
index 0000000000..ba8f290131
--- /dev/null
+++ b/helper/rlp/unsafe.go
@@ -0,0 +1,37 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+//go:build !nacl && !js && cgo
+// +build !nacl,!js,cgo
+
+package rlp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// byteArrayBytes returns a slice of the byte array v.
+func byteArrayBytes(v reflect.Value, length int) []byte {
+ var s []byte
+
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s))
+ hdr.Data = v.UnsafeAddr()
+ hdr.Cap = length
+ hdr.Len = length
+
+ return s
+}
diff --git a/jsonrpc/eth_blockchain_test.go b/jsonrpc/eth_blockchain_test.go
index 3c7f2dd265..7fbe037563 100644
--- a/jsonrpc/eth_blockchain_test.go
+++ b/jsonrpc/eth_blockchain_test.go
@@ -120,7 +120,6 @@ func TestEth_GetTransactionByHash(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, res)
- //nolint:forcetypeassert
foundTxn := res.(*transaction)
assert.Equal(t, argUint64(testTxn.Nonce), foundTxn.Nonce)
assert.Equal(t, argUint64(block.Number()), *foundTxn.BlockNumber)
@@ -143,7 +142,6 @@ func TestEth_GetTransactionByHash(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, res)
- //nolint:forcetypeassert
foundTxn := res.(*transaction)
assert.Equal(t, argUint64(testTxn.Nonce), foundTxn.Nonce)
assert.Nil(t, foundTxn.BlockNumber)
@@ -196,7 +194,6 @@ func TestEth_GetTransactionReceipt(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, res)
- //nolint:forcetypeassert
response := res.(*receipt)
assert.Equal(t, txn.Hash(), response.TxHash)
assert.Equal(t, block.Hash(), response.BlockHash)
@@ -216,7 +213,6 @@ func TestEth_Syncing(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, res)
- //nolint:forcetypeassert
response := res.(progression)
assert.NotEqual(t, progress.ChainSyncBulk, response.Type)
assert.Equal(t, fmt.Sprintf("0x%x", 1), response.StartingBlock)
@@ -230,7 +226,6 @@ func TestEth_Syncing(t *testing.T) {
res, err := eth.Syncing()
assert.NoError(t, err)
- //nolint:forcetypeassert
assert.False(t, res.(bool))
})
}
@@ -244,7 +239,6 @@ func TestEth_GasPrice(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, res)
- //nolint:forcetypeassert
response := res.(string)
assert.Equal(t, fmt.Sprintf("0x%x", store.averageGasPrice), response)
}
@@ -451,7 +445,7 @@ func (m *mockBlockStore) ReadTxLookup(txnHash types.Hash) (types.Hash, bool) {
}
}
- return types.ZeroHash, false
+ return types.Hash{}, false
}
func (m *mockBlockStore) GetPendingTx(txHash types.Hash) (*types.Transaction, bool) {
diff --git a/jsonrpc/eth_endpoint.go b/jsonrpc/eth_endpoint.go
index 4d0684d347..7e1145401f 100644
--- a/jsonrpc/eth_endpoint.go
+++ b/jsonrpc/eth_endpoint.go
@@ -10,8 +10,8 @@ import (
"github.com/dogechain-lab/dogechain/helper/progress"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
- "github.com/dogechain-lab/fastrlp"
"github.com/hashicorp/go-hclog"
)
@@ -31,8 +31,8 @@ type ethTxPoolStore interface {
}
type ethStateStore interface {
- GetAccount(root types.Hash, addr types.Address) (*state.Account, error)
- GetStorage(root types.Hash, addr types.Address, slot types.Hash) ([]byte, error)
+ GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error)
+ GetStorage(stateRoot types.Hash, addr types.Address, slot types.Hash) (types.Hash, error)
GetForksInTime(blockNumber uint64) chain.ForksInTime
GetCode(stateRoot types.Hash, accoun types.Address) ([]byte, error)
}
@@ -414,7 +414,7 @@ func (e *Eth) GetTransactionReceipt(hash types.Hash) (interface{}, error) {
// GetStorageAt returns the contract storage at the index position
func (e *Eth) GetStorageAt(
address types.Address,
- index types.Hash,
+ slot types.Hash,
filter BlockNumberOrHash,
) (interface{}, error) {
e.metrics.EthAPICounterInc(EthGetStorageAtLabel)
@@ -435,30 +435,16 @@ func (e *Eth) GetStorageAt(
}
// Get the storage for the passed in location
- result, err := e.store.GetStorage(header.StateRoot, address, index)
+ result, err := e.store.GetStorage(header.StateRoot, address, slot)
if err != nil {
if errors.Is(err, ErrStateNotFound) {
- return argBytesPtr(types.ZeroHash[:]), nil
+ return argBytesPtr(result.Bytes()), nil
}
return nil, err
}
- // Parse the RLP value
- p := &fastrlp.Parser{}
- v, err := p.Parse(result)
- if err != nil {
- return argBytesPtr(types.ZeroHash[:]), nil
- }
-
- data, err := v.Bytes()
-
- if err != nil {
- return argBytesPtr(types.ZeroHash[:]), nil
- }
-
- // Pad to return 32 bytes data
- return argBytesPtr(types.BytesToHash(data).Bytes()), nil
+ return argBytesPtr(result.Bytes()), nil
}
// GasPrice returns the average gas price based on the last x blocks
diff --git a/jsonrpc/eth_endpoint_test.go b/jsonrpc/eth_endpoint_test.go
index 78e987f079..6538a8beed 100644
--- a/jsonrpc/eth_endpoint_test.go
+++ b/jsonrpc/eth_endpoint_test.go
@@ -5,7 +5,7 @@ import (
"math/big"
"testing"
- "github.com/dogechain-lab/dogechain/state"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
@@ -14,7 +14,7 @@ import (
func TestEth_DecodeTxn(t *testing.T) {
tests := []struct {
name string
- accounts map[types.Address]*state.Account
+ accounts map[types.Address]*stypes.Account
arg *txnArgs
res *types.Transaction
err error
@@ -75,7 +75,7 @@ func TestEth_DecodeTxn(t *testing.T) {
},
{
name: "should set latest nonce as default",
- accounts: map[types.Address]*state.Account{
+ accounts: map[types.Address]*stypes.Account{
addr1: {
Nonce: 10,
},
@@ -164,11 +164,11 @@ func TestEth_GetNextNonce(t *testing.T) {
// Set up the mock accounts
accounts := []struct {
address types.Address
- account *state.Account
+ account *stypes.Account
}{
{
types.StringToAddress("123"),
- &state.Account{
+ &stypes.Account{
Nonce: 5,
},
},
diff --git a/jsonrpc/eth_state_test.go b/jsonrpc/eth_state_test.go
index 1922e4a6de..535db769b5 100644
--- a/jsonrpc/eth_state_test.go
+++ b/jsonrpc/eth_state_test.go
@@ -10,6 +10,7 @@ import (
"github.com/dogechain-lab/dogechain/helper/hex"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
"github.com/dogechain-lab/fastrlp"
"github.com/stretchr/testify/assert"
@@ -25,14 +26,14 @@ func TestEth_State_GetBalance(t *testing.T) {
store := &mockSpecialStore{
account: &mockAccount{
address: addr0,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
},
storage: make(map[types.Hash][]byte),
},
block: &types.Block{
Header: &types.Header{
- Hash: types.ZeroHash,
+ Hash: types.Hash{},
Number: 0,
StateRoot: types.EmptyRootHash,
},
@@ -98,7 +99,7 @@ func TestEth_State_GetBalance(t *testing.T) {
addr0,
false,
nil,
- &types.ZeroHash,
+ &types.Hash{},
100,
},
{
@@ -157,7 +158,7 @@ func TestEth_State_GetTransactionCount(t *testing.T) {
store := &mockSpecialStore{
account: &mockAccount{
address: addr0,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
Nonce: 100,
},
@@ -165,7 +166,7 @@ func TestEth_State_GetTransactionCount(t *testing.T) {
},
block: &types.Block{
Header: &types.Header{
- Hash: types.ZeroHash,
+ Hash: types.Hash{},
Number: 0,
StateRoot: types.EmptyRootHash,
},
@@ -230,7 +231,7 @@ func TestEth_State_GetTransactionCount(t *testing.T) {
"should return valid nonce for valid block hash",
addr0,
nil,
- &types.ZeroHash,
+ &types.Hash{},
false,
100,
},
@@ -275,7 +276,7 @@ func TestEth_State_GetCode(t *testing.T) {
store := &mockSpecialStore{
account: &mockAccount{
address: addr0,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
Nonce: 100,
CodeHash: types.BytesToHash(addr0.Bytes()).Bytes(),
@@ -284,7 +285,7 @@ func TestEth_State_GetCode(t *testing.T) {
},
block: &types.Block{
Header: &types.Header{
- Hash: types.ZeroHash,
+ Hash: types.Hash{},
Number: 0,
StateRoot: types.EmptyRootHash,
},
@@ -351,7 +352,7 @@ func TestEth_State_GetCode(t *testing.T) {
"should return a valid code for valid block hash",
addr0,
nil,
- &types.ZeroHash,
+ &types.Hash{},
false,
code0,
},
@@ -400,7 +401,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
store := &mockSpecialStore{
account: &mockAccount{
address: addr0,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
Nonce: 100,
},
@@ -408,7 +409,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
block: &types.Block{
Header: &types.Header{
- Hash: types.ZeroHash,
+ Hash: types.Hash{},
Number: 0,
StateRoot: types.EmptyRootHash,
},
@@ -425,7 +426,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
name string
initialStorage map[types.Address]map[types.Hash]types.Hash
address types.Address
- index types.Hash
+ slot types.Hash
blockNumber *BlockNumber
blockHash *types.Hash
succeeded bool
@@ -439,7 +440,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash1,
+ slot: hash1,
blockNumber: nil,
blockHash: nil,
succeeded: true,
@@ -453,11 +454,11 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash2,
+ slot: hash2,
blockNumber: &blockNumberLatest,
blockHash: nil,
succeeded: true,
- expectedData: argBytesPtr(types.ZeroHash[:]),
+ expectedData: argBytesPtr(types.Hash{}.Bytes()),
},
{
name: "should return 32 bytes filled with zero for non-existing account",
@@ -467,10 +468,10 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash2,
+ slot: hash2,
blockNumber: &blockNumberLatest,
succeeded: true,
- expectedData: argBytesPtr(types.ZeroHash[:]),
+ expectedData: argBytesPtr(types.Hash{}.Bytes()),
},
{
name: "should return error for invalid block number",
@@ -480,7 +481,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash2,
+ slot: hash2,
blockNumber: &blockNumberInvalid,
blockHash: nil,
succeeded: false,
@@ -494,7 +495,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash1,
+ slot: hash1,
blockNumber: &blockNumberZero,
blockHash: nil,
succeeded: true,
@@ -508,9 +509,9 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash1,
+ slot: hash1,
blockNumber: nil,
- blockHash: &types.ZeroHash,
+ blockHash: &types.Hash{},
succeeded: true,
expectedData: argBytesPtr(hash1[:]),
},
@@ -522,7 +523,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash2,
+ slot: hash2,
blockNumber: nil,
blockHash: &hash1,
succeeded: false,
@@ -536,7 +537,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
},
},
address: addr0,
- index: hash1,
+ slot: hash1,
blockNumber: &blockNumberEarliest,
blockHash: nil,
succeeded: true,
@@ -549,7 +550,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
for addr, storage := range tt.initialStorage {
store.account = &mockAccount{
address: addr,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
Nonce: 100,
},
@@ -569,7 +570,7 @@ func TestEth_State_GetStorageAt(t *testing.T) {
BlockHash: tt.blockHash,
}
- res, err := eth.GetStorageAt(tt.address, tt.index, filter)
+ res, err := eth.GetStorageAt(tt.address, tt.slot, filter)
if tt.succeeded {
assert.NoError(t, err)
assert.NotNil(t, res)
@@ -597,7 +598,7 @@ func getExampleStore() *mockSpecialStore {
return &mockSpecialStore{
account: &mockAccount{
address: addr0,
- account: &state.Account{
+ account: &stypes.Account{
Balance: big.NewInt(100),
Nonce: 0,
},
@@ -781,7 +782,7 @@ func (m *mockSpecialStore) GetBlockByHash(hash types.Hash, full bool) (*types.Bl
return m.block, true
}
-func (m *mockSpecialStore) GetAccount(root types.Hash, addr types.Address) (*state.Account, error) {
+func (m *mockSpecialStore) GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error) {
if m.account.address != addr {
return nil, ErrStateNotFound
}
@@ -805,19 +806,19 @@ func (m *mockSpecialStore) GetNonce(addr types.Address) uint64 {
return 1
}
-func (m *mockSpecialStore) GetStorage(root types.Hash, addr types.Address, slot types.Hash) ([]byte, error) {
+func (m *mockSpecialStore) GetStorage(stateRoot types.Hash, addr types.Address, slot types.Hash) (types.Hash, error) {
if m.account.address != addr {
- return nil, ErrStateNotFound
+ return types.Hash{}, ErrStateNotFound
}
acct := m.account
- val, ok := acct.storage[slot]
+ val, ok := acct.storage[slot]
if !ok {
- return nil, ErrStateNotFound
+ return types.Hash{}, ErrStateNotFound
}
- return val, nil
+ return types.BytesToHash(val), nil
}
func (m *mockSpecialStore) GetCode(stateRoot types.Hash, addr types.Address) ([]byte, error) {
diff --git a/jsonrpc/eth_txpool_test.go b/jsonrpc/eth_txpool_test.go
index f73fad7404..650703dec5 100644
--- a/jsonrpc/eth_txpool_test.go
+++ b/jsonrpc/eth_txpool_test.go
@@ -5,7 +5,7 @@ import (
"testing"
"github.com/dogechain-lab/dogechain/helper/hex"
- "github.com/dogechain-lab/dogechain/state"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
"github.com/stretchr/testify/assert"
)
@@ -70,7 +70,7 @@ func (m *mockStoreTxn) AddAccount(addr types.Address) *mockAccount {
acct := &mockAccount{
address: addr,
- account: &state.Account{},
+ account: &stypes.Account{},
storage: make(map[types.Hash][]byte),
}
m.accounts[addr] = acct
@@ -82,7 +82,7 @@ func (m *mockStoreTxn) Header() *types.Header {
return &types.Header{}
}
-func (m *mockStoreTxn) GetAccount(root types.Hash, addr types.Address) (*state.Account, error) {
+func (m *mockStoreTxn) GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error) {
acct, ok := m.accounts[addr]
if !ok {
return nil, ErrStateNotFound
diff --git a/jsonrpc/jsonrpc.go b/jsonrpc/jsonrpc.go
index a3906e8eb0..f74cf8ec87 100644
--- a/jsonrpc/jsonrpc.go
+++ b/jsonrpc/jsonrpc.go
@@ -45,6 +45,7 @@ type JSONRPC struct {
config *Config
dispatcher dispatcher
metrics *Metrics
+ server *http.Server
}
type dispatcher interface {
@@ -103,6 +104,17 @@ func NewJSONRPC(logger hclog.Logger, config *Config) (*JSONRPC, error) {
return srv, nil
}
+func (j *JSONRPC) Close() error {
+ if j.server == nil {
+ return nil
+ }
+
+ err := j.server.Close()
+ j.server = nil
+
+ return err
+}
+
func (j *JSONRPC) setupHTTP() error {
j.logger.Info("http server started", "addr", j.config.Addr.String())
@@ -131,7 +143,7 @@ func (j *JSONRPC) setupHTTP() error {
mux.HandleFunc("/ws", j.handleWs)
}
- srv := http.Server{
+ srv := &http.Server{
Handler: mux,
ReadTimeout: time.Minute,
ReadHeaderTimeout: time.Minute,
@@ -139,6 +151,8 @@ func (j *JSONRPC) setupHTTP() error {
IdleTimeout: 2 * time.Minute,
}
+ j.server = srv
+
go func() {
if err := srv.Serve(lis); err != nil {
j.logger.Error("closed http connection", "err", err)
diff --git a/jsonrpc/mocks_test.go b/jsonrpc/mocks_test.go
index 7fd6bd4279..b364f6c8a3 100644
--- a/jsonrpc/mocks_test.go
+++ b/jsonrpc/mocks_test.go
@@ -5,14 +5,14 @@ import (
"sync"
"github.com/dogechain-lab/dogechain/blockchain"
- "github.com/dogechain-lab/dogechain/state"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
)
type mockAccount struct {
address types.Address
code []byte
- account *state.Account
+ account *stypes.Account
storage map[types.Hash][]byte
}
@@ -51,14 +51,14 @@ type mockStore struct {
subscription *blockchain.MockSubscription
receiptsLock sync.Mutex
receipts map[types.Hash][]*types.Receipt
- accounts map[types.Address]*state.Account
+ accounts map[types.Address]*stypes.Account
}
func newMockStore() *mockStore {
return &mockStore{
header: &types.Header{Number: 0},
subscription: blockchain.NewMockSubscription(),
- accounts: map[types.Address]*state.Account{},
+ accounts: map[types.Address]*stypes.Account{},
}
}
@@ -88,7 +88,7 @@ func (m *mockStore) emitEvent(evnt *mockEvent) {
m.subscription.Push(bEvnt)
}
-func (m *mockStore) GetAccount(root types.Hash, addr types.Address) (*state.Account, error) {
+func (m *mockStore) GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error) {
if acc, ok := m.accounts[addr]; ok {
return acc, nil
}
@@ -96,7 +96,7 @@ func (m *mockStore) GetAccount(root types.Hash, addr types.Address) (*state.Acco
return nil, ErrStateNotFound
}
-func (m *mockStore) SetAccount(addr types.Address, account *state.Account) {
+func (m *mockStore) SetAccount(addr types.Address, account *stypes.Account) {
m.accounts[addr] = account
}
diff --git a/jsonrpc/txpool_endpoint.go b/jsonrpc/txpool_endpoint.go
index cedd05c346..4a8a9d7b20 100644
--- a/jsonrpc/txpool_endpoint.go
+++ b/jsonrpc/txpool_endpoint.go
@@ -67,7 +67,7 @@ func toTxPoolTransaction(t *types.Transaction) *txpoolTransaction {
Input: t.Input,
Hash: t.Hash(),
From: t.From,
- BlockHash: types.ZeroHash,
+ BlockHash: types.Hash{},
BlockNumber: nil,
TxIndex: nil,
}
diff --git a/jsonrpc/txpool_endpoint_test.go b/jsonrpc/txpool_endpoint_test.go
index ef32d01da5..1ddd7622e2 100644
--- a/jsonrpc/txpool_endpoint_test.go
+++ b/jsonrpc/txpool_endpoint_test.go
@@ -16,7 +16,6 @@ func TestContentEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Content()
- //nolint:forcetypeassert
response := result.(ContentResponse)
assert.True(t, mockStore.includeQueued)
@@ -24,7 +23,6 @@ func TestContentEndpoint(t *testing.T) {
assert.Equal(t, 0, len(response.Queued))
})
- //nolint:dupl
t.Run("returns correct data for pending transaction", func(t *testing.T) {
mockStore := newMockTxPoolStore()
address1 := types.Address{0x1}
@@ -33,7 +31,6 @@ func TestContentEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Content()
- //nolint:forcetypeassert
response := result.(ContentResponse)
assert.Equal(t, 1, len(response.Pending))
@@ -52,7 +49,6 @@ func TestContentEndpoint(t *testing.T) {
assert.Equal(t, nil, txData.TxIndex)
})
- //nolint:dupl
t.Run("returns correct data for queued transaction", func(t *testing.T) {
mockStore := newMockTxPoolStore()
address1 := types.Address{0x1}
@@ -61,7 +57,6 @@ func TestContentEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Content()
- //nolint:forcetypeassert
response := result.(ContentResponse)
assert.Equal(t, 0, len(response.Pending))
@@ -96,7 +91,6 @@ func TestContentEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Content()
- //nolint:forcetypeassert
response := result.(ContentResponse)
assert.True(t, mockStore.includeQueued)
@@ -114,7 +108,6 @@ func TestInspectEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Inspect()
- //nolint:forcetypeassert
response := result.(InspectResponse)
assert.True(t, mockStore.includeQueued)
@@ -133,7 +126,6 @@ func TestInspectEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Inspect()
- //nolint:forcetypeassert
response := result.(InspectResponse)
assert.Equal(t, 0, len(response.Pending))
@@ -154,7 +146,6 @@ func TestInspectEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Inspect()
- //nolint:forcetypeassert
response := result.(InspectResponse)
assert.Equal(t, 1, len(response.Pending))
@@ -173,7 +164,6 @@ func TestStatusEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Status()
- //nolint:forcetypeassert
response := result.(StatusResponse)
assert.Equal(t, uint64(0), response.Pending)
@@ -196,7 +186,6 @@ func TestStatusEndpoint(t *testing.T) {
txPoolEndpoint := &TxPool{mockStore, NilMetrics()}
result, _ := txPoolEndpoint.Status()
- //nolint:forcetypeassert
response := result.(StatusResponse)
assert.Equal(t, uint64(3), response.Pending)
diff --git a/reverify/chain.go b/reverify/chain.go
index ed5a279f34..2bf8385f54 100644
--- a/reverify/chain.go
+++ b/reverify/chain.go
@@ -3,17 +3,18 @@ package reverify
import (
"context"
"fmt"
- "path/filepath"
"github.com/dogechain-lab/dogechain/blockchain"
"github.com/dogechain-lab/dogechain/blockchain/storage/kvstorage"
"github.com/dogechain-lab/dogechain/chain"
"github.com/dogechain-lab/dogechain/consensus"
"github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/network"
"github.com/dogechain-lab/dogechain/secrets"
"github.com/dogechain-lab/dogechain/server"
"github.com/dogechain-lab/dogechain/state"
+ "github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
@@ -21,15 +22,6 @@ import (
"github.com/dogechain-lab/dogechain/state/runtime/precompiled"
)
-func newLevelDBBuilder(log hclog.Logger, path string) kvdb.LevelDBBuilder {
- leveldbBuilder := kvdb.NewLevelDBBuilder(
- log,
- path,
- )
-
- return leveldbBuilder
-}
-
func createConsensus(
logger hclog.Logger,
genesis *chain.Chain,
@@ -38,8 +30,8 @@ func createConsensus(
dataDir string,
) (consensus.Consensus, error) {
engineName := genesis.Params.GetEngine()
- engine, ok := server.GetConsensusBackend(engineName)
+ engine, ok := server.GetConsensusBackend(engineName)
if !ok {
return nil, fmt.Errorf("consensus engine '%s' not found", engineName)
}
@@ -72,7 +64,7 @@ func createConsensus(
config := &consensus.Config{
Params: genesis.Params,
Config: engineConfig,
- Path: filepath.Join(dataDir, "consensus"),
+ Path: consensusDir(dataDir),
}
consensus, err := engine(
@@ -85,7 +77,7 @@ func createConsensus(
Blockchain: blockchain,
Executor: executor,
Grpc: nil,
- Logger: logger.Named("consensus"),
+ Logger: logger.Named(_consensusDir),
Metrics: nil,
SecretsManager: secretsManager,
BlockTime: 2,
@@ -101,12 +93,20 @@ func createConsensus(
}
func createBlockchain(
+ hub *DBHub,
+ db kvdb.KVBatchStorage,
logger hclog.Logger,
genesis *chain.Chain,
st itrie.StateDB,
dataDir string,
) (*blockchain.Blockchain, consensus.Consensus, error) {
- executor := state.NewExecutor(genesis.Params, st, logger)
+ // TODO: confirm whether snapshots are needed here
+ executor := state.NewExecutor(
+ genesis.Params,
+ logger,
+ st,
+ )
+
executor.SetRuntime(precompiled.NewPrecompiled())
executor.SetRuntime(evm.NewEVM())
@@ -120,11 +120,8 @@ func createBlockchain(
chain, err := blockchain.NewBlockchain(
logger,
genesis,
- kvstorage.NewLevelDBStorageBuilder(
- logger,
- newLevelDBBuilder(logger, filepath.Join(dataDir, "blockchain")),
- ),
nil,
+ kvstorage.NewKeyValueStorage(db),
executor,
nil,
)
@@ -132,7 +129,7 @@ func createBlockchain(
return nil, nil, err
}
- executor.GetHash = chain.GetHashHelper
+ executor.GetHash = hub.GetHashHelper
consensus, err := createConsensus(logger, genesis, chain, executor, dataDir)
if err != nil {
@@ -156,3 +153,15 @@ func createBlockchain(
return chain, consensus, nil
}
+
+type DBHub struct {
+ chainDB kvdb.KVBatchStorage
+}
+
+func (d *DBHub) GetHashHelper(header *types.Header) func(uint64) types.Hash {
+ return func(u uint64) types.Hash {
+ v, _ := rawdb.ReadCanonicalHash(d.chainDB, u)
+
+ return v
+ }
+}
diff --git a/reverify/reverify.go b/reverify/reverify.go
index bc48fa496b..4518bbedfe 100644
--- a/reverify/reverify.go
+++ b/reverify/reverify.go
@@ -7,28 +7,62 @@ import (
"github.com/hashicorp/go-hclog"
"github.com/dogechain-lab/dogechain/chain"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/leveldb"
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
)
+const (
+ _blockchainDir = "blockchain"
+ _stateDir = "trie"
+ _consensusDir = "consensus"
+)
+
+func stateDir(dataDir string) string {
+ return filepath.Join(dataDir, _stateDir)
+}
+
+func blockchainDir(dataDir string) string {
+ return filepath.Join(dataDir, _blockchainDir)
+}
+
+func consensusDir(dataDir string) string {
+ return filepath.Join(dataDir, _consensusDir)
+}
+
func ReverifyChain(
logger hclog.Logger,
chain *chain.Chain,
dataDir string,
startHeight uint64,
) error {
- stateStorage, err := itrie.NewLevelDBStorage(
- newLevelDBBuilder(logger, filepath.Join(dataDir, "trie")))
+ chainDB, err := leveldb.New(
+ blockchainDir(dataDir),
+ leveldb.SetLogger(logger),
+ )
if err != nil {
- logger.Error("failed to create state storage")
+ return err
+ }
+
+ defer chainDB.Close()
+
+ hub := &DBHub{chainDB: chainDB}
+ trieDB, err := leveldb.New(
+ stateDir(dataDir),
+ leveldb.SetLogger(logger),
+ )
+ if err != nil {
return err
}
- defer stateStorage.Close()
+
+ defer trieDB.Close()
blockchain, consensus, err := createBlockchain(
+ hub,
+ chainDB,
logger,
chain,
- itrie.NewStateDB(stateStorage, hclog.NewNullLogger(), itrie.NilMetrics()),
+ itrie.NewStateDB(trieDB, hclog.NewNullLogger(), itrie.NilMetrics()),
dataDir,
)
if err != nil {
@@ -36,8 +70,11 @@ func ReverifyChain(
return err
}
- defer blockchain.Close()
- defer consensus.Close()
+
+ defer func() {
+ consensus.Close()
+ blockchain.Close()
+ }()
hash, ok := blockchain.GetHeaderHash()
if ok {
diff --git a/server/config.go b/server/config.go
index e85a4d31ce..2fd6bcce09 100644
--- a/server/config.go
+++ b/server/config.go
@@ -2,6 +2,7 @@ package server
import (
"net"
+ "time"
"github.com/hashicorp/go-hclog"
@@ -10,10 +11,14 @@ import (
"github.com/dogechain-lab/dogechain/secrets"
)
-const DefaultGRPCPort int = 9632
-const DefaultJSONRPCPort int = 8545
-const DefaultGraphQLPort int = 9898
-const DefaultPprofPort int = 6060
+const (
+ DefaultGRPCPort int = 9632
+ DefaultJSONRPCPort int = 8545
+ DefaultGraphQLPort int = 9898
+ DefaultPprofPort int = 6060
+
+ TriesInMemory = 128
+)
// Config is used to parametrize the minimal client
type Config struct {
@@ -49,6 +54,11 @@ type Config struct {
ValidatorKey string
BlockBroadcast bool
+
+ // enable snapshots, disable by default
+ EnableSnapshot bool
+ // cache config, mostly for snapshots
+ CacheConfig *CacheConfig
}
// LeveldbOptions holds the leveldb options
@@ -84,3 +94,18 @@ type GraphQL struct {
BlockRangeLimit uint64
EnablePprof bool
}
+
+// CacheConfig contains the configuration values for the trie database
+// that's resident in a blockchain.
+type CacheConfig struct {
+ TrieCleanLimit int // Memory allowance (MB) to use for caching trie nodes in memory
+ TrieCleanJournal string // Disk journal for saving clean cache entries.
+ TrieCleanRejournal time.Duration // Time interval to dump clean cache to disk periodically
+ TrieDirtyLimit int // Memory limit (MB) at which to start flushing dirty trie nodes to disk
+ TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
+ SnapshotLimit int // Memory allowance (MB) to use for caching snapshot entries in memory
+
+ SnapshotNoBuild bool // Whether the background generation is allowed
+ // Wait for snapshot construction on startup.
+ SnapshotWait bool
+}
diff --git a/server/jsonrpc_store.go b/server/jsonrpc_store.go
index 993b490447..10cf987851 100644
--- a/server/jsonrpc_store.go
+++ b/server/jsonrpc_store.go
@@ -1,6 +1,7 @@
package server
import (
+ "bytes"
"errors"
"fmt"
"math/big"
@@ -13,6 +14,8 @@ import (
"github.com/dogechain-lab/dogechain/network"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/txpool"
"github.com/dogechain-lab/dogechain/types"
)
@@ -26,11 +29,13 @@ type jsonRPCStore struct {
consensus consensus.Consensus
server network.Server
state state.State
+ snaps *snapshot.Tree
- metrics *JSONRPCStoreMetrics
+ metrics *jsonrpcStoreMetrics
}
func NewJSONRPCStore(
+ snaps *snapshot.Tree,
state state.State,
blockchain *blockchain.Blockchain,
restoreProgression *progress.ProgressionWrapper,
@@ -38,7 +43,7 @@ func NewJSONRPCStore(
executor *state.Executor,
consensus consensus.Consensus,
network network.Server,
- metrics *JSONRPCStoreMetrics,
+ metrics *jsonrpcStoreMetrics,
) jsonrpc.JSONRPCStore {
if metrics == nil {
metrics = JSONRPCStoreNilMetrics()
@@ -52,6 +57,7 @@ func NewJSONRPCStore(
consensus: consensus,
server: network,
state: state,
+ snaps: snaps,
metrics: metrics,
}
}
@@ -80,32 +86,16 @@ func (j *jsonRPCStore) GetPendingTx(txHash types.Hash) (*types.Transaction, bool
}
// jsonrpc.ethStateStore interface
-func (j *jsonRPCStore) GetAccount(root types.Hash, addr types.Address) (*state.Account, error) {
+func (j *jsonRPCStore) GetAccount(stateRoot types.Hash, addr types.Address) (*stypes.Account, error) {
j.metrics.GetAccountInc()
- return getAccountImpl(j.state, root, addr)
+ return getCommittedAccount(j.snaps, j.state, stateRoot, addr)
}
-func (j *jsonRPCStore) GetStorage(root types.Hash, addr types.Address, slot types.Hash) ([]byte, error) {
+func (j *jsonRPCStore) GetStorage(stateRoot types.Hash, addr types.Address, slot types.Hash) (types.Hash, error) {
j.metrics.GetStorageInc()
- account, err := getAccountImpl(j.state, root, addr)
- if err != nil {
- return nil, err
- }
-
- // make a snapshot at root
- snap, err := j.state.NewSnapshotAt(root)
- if err != nil {
- return nil, err
- }
-
- resp, err := snap.GetStorage(addr, account.Root, slot)
- if err != nil {
- return nil, err
- }
-
- return resp.Bytes(), nil
+ return getCommittedStorage(j.snaps, j.state, stateRoot, addr, slot)
}
// GetForksInTime returns the active forks at the given block height
@@ -119,11 +109,15 @@ func (j *jsonRPCStore) GetForksInTime(blockNumber uint64) chain.ForksInTime {
func (j *jsonRPCStore) GetCode(root types.Hash, addr types.Address) ([]byte, error) {
j.metrics.GetCodeInc()
- account, err := getAccountImpl(j.state, root, addr)
+ account, err := getCommittedAccount(j.snaps, j.state, root, addr)
if err != nil {
return nil, err
}
+ if len(account.CodeHash) == 0 || bytes.Equal(account.CodeHash, types.EmptyRootHash[:]) {
+ return []byte{}, nil
+ }
+
code, ok := j.state.GetCode(types.BytesToHash(account.CodeHash))
if !ok {
return nil, fmt.Errorf("unable to fetch code")
diff --git a/server/jsonrpc_store_metrics.go b/server/jsonrpc_store_metrics.go
index 96db859a5e..48e953333c 100644
--- a/server/jsonrpc_store_metrics.go
+++ b/server/jsonrpc_store_metrics.go
@@ -5,169 +5,169 @@ import (
"github.com/prometheus/client_golang/prometheus"
)
-type JSONRPCStoreMetrics struct {
+type jsonrpcStoreMetrics struct {
counter *prometheus.CounterVec
}
// GetNonce api calls
-func (m *JSONRPCStoreMetrics) GetNonceInc() {
+func (m *jsonrpcStoreMetrics) GetNonceInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetNonce"}).Inc()
}
}
// AddTx api calls
-func (m *JSONRPCStoreMetrics) AddTxInc() {
+func (m *jsonrpcStoreMetrics) AddTxInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "AddTx"}).Inc()
}
}
// GetPendingTx api calls
-func (m *JSONRPCStoreMetrics) GetPendingTxInc() {
+func (m *jsonrpcStoreMetrics) GetPendingTxInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetPendingTx"}).Inc()
}
}
// GetAccount api calls
-func (m *JSONRPCStoreMetrics) GetAccountInc() {
+func (m *jsonrpcStoreMetrics) GetAccountInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetAccount"}).Inc()
}
}
// GetGetStorage api calls
-func (m *JSONRPCStoreMetrics) GetStorageInc() {
+func (m *jsonrpcStoreMetrics) GetStorageInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetStorage"}).Inc()
}
}
// GetForksInTime api calls
-func (m *JSONRPCStoreMetrics) GetForksInTimeInc() {
+func (m *jsonrpcStoreMetrics) GetForksInTimeInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetForksInTime"}).Inc()
}
}
// GetCode api calls
-func (m *JSONRPCStoreMetrics) GetCodeInc() {
+func (m *jsonrpcStoreMetrics) GetCodeInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetCode"}).Inc()
}
}
// Header api calls
-func (m *JSONRPCStoreMetrics) HeaderInc() {
+func (m *jsonrpcStoreMetrics) HeaderInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "Header"}).Inc()
}
}
// GetHeaderByNumber api calls
-func (m *JSONRPCStoreMetrics) GetHeaderByNumberInc() {
+func (m *jsonrpcStoreMetrics) GetHeaderByNumberInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetHeaderByNumber"}).Inc()
}
}
// GetHeaderByHash api calls
-func (m *JSONRPCStoreMetrics) GetHeaderByHashInc() {
+func (m *jsonrpcStoreMetrics) GetHeaderByHashInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetHeaderByHash"}).Inc()
}
}
// GetBlockByHash api calls
-func (m *JSONRPCStoreMetrics) GetBlockByHashInc() {
+func (m *jsonrpcStoreMetrics) GetBlockByHashInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetBlockByHash"}).Inc()
}
}
// GetBlockByNumber api calls
-func (m *JSONRPCStoreMetrics) GetBlockByNumberInc() {
+func (m *jsonrpcStoreMetrics) GetBlockByNumberInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetBlockByNumber"}).Inc()
}
}
// ReadTxLookup api calls
-func (m *JSONRPCStoreMetrics) ReadTxLookupInc() {
+func (m *jsonrpcStoreMetrics) ReadTxLookupInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "ReadTxLookup"}).Inc()
}
}
// GetReceiptsByHash api calls
-func (m *JSONRPCStoreMetrics) GetReceiptsByHashInc() {
+func (m *jsonrpcStoreMetrics) GetReceiptsByHashInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetReceiptsByHash"}).Inc()
}
}
// GetAvgGasPrice api calls
-func (m *JSONRPCStoreMetrics) GetAvgGasPriceInc() {
+func (m *jsonrpcStoreMetrics) GetAvgGasPriceInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetAvgGasPrice"}).Inc()
}
}
// ApplyTxn api calls
-func (m *JSONRPCStoreMetrics) ApplyTxnInc() {
+func (m *jsonrpcStoreMetrics) ApplyTxnInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "ApplyTxn"}).Inc()
}
}
// GetSyncProgression api calls
-func (m *JSONRPCStoreMetrics) GetSyncProgressionInc() {
+func (m *jsonrpcStoreMetrics) GetSyncProgressionInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetSyncProgression"}).Inc()
}
}
// StateAtTransaction api calls
-func (m *JSONRPCStoreMetrics) StateAtTransactionInc() {
+func (m *jsonrpcStoreMetrics) StateAtTransactionInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "StateAtTransaction"}).Inc()
}
}
// PeerCount api calls
-func (m *JSONRPCStoreMetrics) PeerCountInc() {
+func (m *jsonrpcStoreMetrics) PeerCountInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "PeerCount"}).Inc()
}
}
// GetTxs api calls
-func (m *JSONRPCStoreMetrics) GetTxsInc() {
+func (m *jsonrpcStoreMetrics) GetTxsInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetTxs"}).Inc()
}
}
// GetCapacity api calls
-func (m *JSONRPCStoreMetrics) GetCapacityInc() {
+func (m *jsonrpcStoreMetrics) GetCapacityInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "GetCapacity"}).Inc()
}
}
// SubscribeEvents api calls
-func (m *JSONRPCStoreMetrics) SubscribeEventsInc() {
+func (m *jsonrpcStoreMetrics) SubscribeEventsInc() {
if m.counter != nil {
m.counter.With(prometheus.Labels{"method": "SubscribeEvents"}).Inc()
}
}
// NewJSONRPCStoreMetrics return the JSONRPCStore metrics instance
-func NewJSONRPCStoreMetrics(namespace string, labelsWithValues ...string) *JSONRPCStoreMetrics {
+func NewJSONRPCStoreMetrics(namespace string, labelsWithValues ...string) *jsonrpcStoreMetrics {
constLabels := metrics.ParseLables(labelsWithValues...)
- m := &JSONRPCStoreMetrics{
+ m := &jsonrpcStoreMetrics{
counter: prometheus.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "jsonrpc_store",
@@ -183,8 +183,8 @@ func NewJSONRPCStoreMetrics(namespace string, labelsWithValues ...string) *JSONR
}
// JSONRPCStoreNilMetrics will return the non operational jsonrpc metrics
-func JSONRPCStoreNilMetrics() *JSONRPCStoreMetrics {
- return &JSONRPCStoreMetrics{
+func JSONRPCStoreNilMetrics() *jsonrpcStoreMetrics {
+ return &jsonrpcStoreMetrics{
counter: nil,
}
}
diff --git a/server/server.go b/server/server.go
index 0e8c22b986..03c6e9839a 100644
--- a/server/server.go
+++ b/server/server.go
@@ -20,7 +20,9 @@ import (
"github.com/dogechain-lab/dogechain/graphql"
"github.com/dogechain-lab/dogechain/helper/common"
"github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/leveldb"
"github.com/dogechain-lab/dogechain/helper/progress"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/jsonrpc"
"github.com/dogechain-lab/dogechain/network"
"github.com/dogechain-lab/dogechain/secrets"
@@ -29,6 +31,10 @@ import (
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
"github.com/dogechain-lab/dogechain/state/runtime/evm"
"github.com/dogechain-lab/dogechain/state/runtime/precompiled"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/state/utils"
+ "github.com/dogechain-lab/dogechain/trie"
"github.com/dogechain-lab/dogechain/txpool"
"github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
@@ -39,23 +45,34 @@ import (
// Minimal is the central manager of the blockchain client
type Server struct {
- logger hclog.Logger
- config *Config
- state state.State
- stateStorage itrie.Storage
+ // configuration
+ config *Config
+ // global logger
+ logger hclog.Logger
+
+ // databases
+ trieDB itrie.Storage
+ chainDB kvdb.Database
+ snpTrieDB *trie.Database // snapshots usage trie database
+
+ // consensus
consensus consensus.Consensus
// blockchain stack
blockchain *blockchain.Blockchain
- chain *chain.Chain
+ // state holder
+ state state.State
// state executor
executor *state.Executor
+ // state snapshots
+ snaps *snapshot.Tree
+ // Cache configuration for pruning
+ cacheConfig *CacheConfig
// jsonrpc stack
jsonrpcServer *jsonrpc.JSONRPC
-
// graphql stack
graphqlServer *graphql.GraphQLService
@@ -80,12 +97,16 @@ type Server struct {
}
const (
- loggerDomainName = "dogechain"
+ loggerDomainName = "dogechain"
+ BlockchainDataDir = "blockchain"
+ StateDataDir = "trie"
+ TrieCacheDir = "triecache"
)
var dirPaths = []string{
- "blockchain",
- "trie",
+ BlockchainDataDir,
+ StateDataDir,
+ TrieCacheDir,
}
// newFileLogger returns logger instance that writes all logs to a specified file.
@@ -133,23 +154,6 @@ func newLoggerFromConfig(config *Config) (hclog.Logger, error) {
return newCLILogger(config), nil
}
-func newLevelDBBuilder(logger hclog.Logger, config *Config, path string) kvdb.LevelDBBuilder {
- leveldbBuilder := kvdb.NewLevelDBBuilder(
- logger,
- path,
- )
-
- // trie cache + blockchain cache = config.LeveldbOptions.CacheSize / 2
- leveldbBuilder.SetCacheSize(config.LeveldbOptions.CacheSize / 2).
- SetHandles(config.LeveldbOptions.Handles).
- SetBloomKeyBits(config.LeveldbOptions.BloomKeyBits).
- SetCompactionTableSize(config.LeveldbOptions.CompactionTableSize).
- SetCompactionTotalSize(config.LeveldbOptions.CompactionTotalSize).
- SetNoSync(config.LeveldbOptions.NoSync)
-
- return leveldbBuilder
-}
-
// NewServer creates a new Minimal server, using the passed in configuration
func NewServer(config *Config) (*Server, error) {
logger, err := newLoggerFromConfig(config)
@@ -157,10 +161,12 @@ func NewServer(config *Config) (*Server, error) {
return nil, fmt.Errorf("could not setup new logger instance, %w", err)
}
- m := &Server{
- logger: logger,
- config: config,
- chain: config.Chain,
+ hclog.SetDefault(logger)
+
+ srv := &Server{
+ logger: logger,
+ config: config,
+ cacheConfig: config.CacheConfig,
grpcServer: grpc.NewServer(
grpc.MaxRecvMsgSize(common.MaxGrpcMsgSize),
grpc.MaxSendMsgSize(common.MaxGrpcMsgSize),
@@ -168,185 +174,401 @@ func NewServer(config *Config) (*Server, error) {
restoreProgression: progress.NewProgressionWrapper(progress.ChainSyncRestore),
}
- m.logger.Info("Data dir", "path", config.DataDir)
+ srv.logger.Info("Data dir", "path", config.DataDir)
// Generate all the paths in the dataDir
if err := common.SetupDataDir(config.DataDir, dirPaths); err != nil {
return nil, fmt.Errorf("failed to create data directories: %w", err)
}
- if config.Telemetry.PrometheusAddr != nil {
- m.serverMetrics = metricProvider("dogechain", config.Chain.Name, true, config.Telemetry.EnableIOMetrics)
- m.prometheusServer = m.startPrometheusServer(config.Telemetry.PrometheusAddr)
- } else {
- m.serverMetrics = metricProvider("dogechain", config.Chain.Name, false, false)
- }
+ // Set up metrics
+ srv.setupMetris()
// Set up the secrets manager
- if err := m.setupSecretsManager(); err != nil {
+ if err := srv.setupSecretsManager(); err != nil {
return nil, fmt.Errorf("failed to set up the secrets manager: %w", err)
}
// start libp2p
- {
- netConfig := config.Network
- netConfig.Chain = m.config.Chain
- netConfig.DataDir = filepath.Join(m.config.DataDir, "libp2p")
- netConfig.SecretsManager = m.secretsManager
- netConfig.Metrics = m.serverMetrics.network
-
- network, err := network.NewServer(logger, netConfig)
- if err != nil {
- return nil, err
- }
- m.network = network
+ if err := srv.setupNetwork(); err != nil {
+ return nil, err
}
- // start blockchain object
- stateStorage, err := func() (itrie.Storage, error) {
- leveldbBuilder := newLevelDBBuilder(
- logger,
- config,
- filepath.Join(m.config.DataDir, "trie"),
- )
-
- return itrie.NewLevelDBStorage(leveldbBuilder)
- }()
-
- if err != nil {
+ // set up blockchain database
+ if err := srv.setupBlockchainDB(); err != nil {
return nil, err
}
- m.stateStorage = stateStorage
-
- st := itrie.NewStateDB(stateStorage, logger, m.serverMetrics.trie)
- m.state = st
-
- m.executor = state.NewExecutor(config.Chain.Params, st, logger)
- m.executor.SetRuntime(precompiled.NewPrecompiled())
- m.executor.SetRuntime(evm.NewEVM())
-
- // compute the genesis root state
- genesisRoot, err := m.executor.WriteGenesis(config.Chain.Genesis.Alloc)
- if err != nil {
+ // set up state database
+ if err := srv.setupStateDB(); err != nil {
return nil, err
}
- config.Chain.Genesis.StateRoot = genesisRoot
-
- // create leveldb storageBuilder
- leveldbBuilder := newLevelDBBuilder(
- logger,
- config,
- filepath.Join(m.config.DataDir, "blockchain"),
- )
+ // set up snapshot trie database
+ srv.setupSnpTrieDB(srv.trieDB)
- // blockchain object
- m.blockchain, err = blockchain.NewBlockchain(
- logger,
- config.Chain,
- kvstorage.NewLevelDBStorageBuilder(logger, leveldbBuilder),
- nil,
- m.executor,
- m.serverMetrics.blockchain,
- )
- if err != nil {
+ // setup executor
+ if err := srv.setupExecutor(); err != nil {
return nil, err
}
- // TODO: refactor the design. Executor and blockchain should not rely on each other.
- m.executor.GetHash = m.blockchain.GetHashHelper
-
- {
- hub := &txpoolHub{
- state: m.state,
- Blockchain: m.blockchain,
- }
-
- blackList := make([]types.Address, len(m.config.Chain.Params.BlackList))
- for i, a := range m.config.Chain.Params.BlackList {
- blackList[i] = types.StringToAddress(a)
- }
+ // setup world state snapshots
+ if err := srv.setupSnapshots(); err != nil {
+ return nil, err
+ }
- // start transaction pool
- m.txpool, err = txpool.NewTxPool(
- logger,
- m.chain.Params.Forks.At(0),
- hub,
- m.grpcServer,
- m.network,
- m.serverMetrics.txpool,
- &txpool.Config{
- Sealing: m.config.Seal,
- MaxSlots: m.config.MaxSlots,
- PriceLimit: m.config.PriceLimit,
- PruneTickSeconds: m.config.PruneTickSeconds,
- PromoteOutdateSeconds: m.config.PromoteOutdateSeconds,
- BlackList: blackList,
- DDOSProtection: m.config.Chain.Params.DDOSProtection,
- },
- )
- if err != nil {
- return nil, err
- }
+ // setup blockchain
+ if err := srv.setupBlockchain(); err != nil {
+ return nil, err
+ }
- // use the eip155 signer
- signer := crypto.NewEIP155Signer(uint64(m.config.Chain.Params.ChainID))
- m.txpool.SetSigner(signer)
+ // Set up txpool
+ if err := srv.setupTxpool(); err != nil {
+ return nil, err
}
- {
- // Setup consensus
- if err := m.setupConsensus(); err != nil {
- return nil, err
- }
- m.blockchain.SetConsensus(m.consensus)
+ // Setup consensus
+ if err := srv.setupConsensus(); err != nil {
+ return nil, err
}
// after consensus is done, we can mine the genesis block in blockchain
// This is done because consensus might use a custom Hash function so we need
// to wait for consensus because we do any block hashing like genesis
- if err := m.blockchain.ComputeGenesis(); err != nil {
+ if err := srv.blockchain.ComputeGenesis(); err != nil {
return nil, err
}
// initialize data in consensus layer
- if err := m.consensus.Initialize(); err != nil {
+ if err := srv.consensus.Initialize(); err != nil {
return nil, err
}
// setup and start jsonrpc server
- if err := m.setupJSONRPC(); err != nil {
+ if err := srv.setupJSONRPC(); err != nil {
return nil, err
}
// setup and start graphql server
- if err := m.setupGraphQL(); err != nil {
+ if err := srv.setupGraphQL(); err != nil {
return nil, err
}
// restore archive data before starting
- if err := m.restoreChain(); err != nil {
+ if err := srv.restoreChain(); err != nil {
return nil, err
}
// start consensus
- if err := m.consensus.Start(); err != nil {
+ if err := srv.consensus.Start(); err != nil {
return nil, err
}
// setup and start grpc server
- if err := m.setupGRPC(); err != nil {
+ if err := srv.setupGRPC(); err != nil {
return nil, err
}
- if err := m.network.Start(); err != nil {
+ // start network to discover peers
+ if err := srv.network.Start(); err != nil {
return nil, err
}
- m.txpool.Start()
+ // start txpool to accept transactions
+ srv.txpool.Start()
+
+ return srv, nil
+}
+
+// setupTxpool sets up the txpool.
+//
+// Must be called after the other components are initialized.
+func (s *Server) setupTxpool() error {
+ var (
+ err error
+ config = s.config
+ hub = &txpoolHub{
+ snaps: s.snaps,
+ state: s.state,
+ Blockchain: s.blockchain,
+ }
+ chainCfg = s.config.Chain
+ )
+
+ blackList := make([]types.Address, len(chainCfg.Params.BlackList))
+ for i, a := range chainCfg.Params.BlackList {
+ blackList[i] = types.StringToAddress(a)
+ }
+
+ txpoolCfg := &txpool.Config{
+ Sealing: config.Seal,
+ MaxSlots: config.MaxSlots,
+ PriceLimit: config.PriceLimit,
+ PruneTickSeconds: config.PruneTickSeconds,
+ PromoteOutdateSeconds: config.PromoteOutdateSeconds,
+ BlackList: blackList,
+ DDOSProtection: chainCfg.Params.DDOSProtection,
+ }
+
+ // start transaction pool
+ s.txpool, err = txpool.NewTxPool(
+ s.logger,
+ chainCfg.Params.Forks.At(0),
+ hub,
+ s.grpcServer,
+ s.network,
+ s.serverMetrics.txpool,
+ txpoolCfg,
+ )
+ if err != nil {
+ return err
+ }
+
+ // use the eip155 signer at the beginning
+ signer := crypto.NewEIP155Signer(uint64(chainCfg.Params.ChainID))
+ s.txpool.SetSigner(signer)
+
+ return nil
+}
+
+func (s *Server) GetHashHelper(header *types.Header) func(uint64) types.Hash {
+ return func(u uint64) types.Hash {
+ v, _ := rawdb.ReadCanonicalHash(s.chainDB, u)
+
+ return v
+ }
+}
+
+func (s *Server) setupBlockchain() error {
+ bc, err := blockchain.NewBlockchain(
+ s.logger,
+ s.config.Chain,
+ nil,
+ kvstorage.NewKeyValueStorage(s.chainDB),
+ s.executor,
+ s.serverMetrics.blockchain,
+ )
+ if err != nil {
+ return err
+ }
+
+ // blockchain object
+ s.blockchain = bc
+
+ return nil
+}
+
+// Load any existing snapshot, regenerating it if loading failed
+func (s *Server) setupSnapshots() error {
+ if s.cacheConfig.SnapshotLimit <= 0 {
+ return nil
+ }
+
+ var (
+ needRecover bool
+ logger = s.logger.Named("snapshots")
+ chainDB = s.chainDB
+ trieDB = s.trieDB
+ headStateRoot types.Hash
+ headNumber uint64
+ isGenesis bool
+ )
+
+ headHash, exists := rawdb.ReadHeadHash(chainDB)
+ if !exists {
+ s.logger.Warn("head hash not found, might generate from genesis")
+
+ isGenesis = true
+
+ // get genesis root
+ headStateRoot = s.config.Chain.Genesis.StateRoot
+ if headStateRoot == types.ZeroHash {
+ return nil
+ }
+ }
+
+ if !isGenesis {
+ header, err := rawdb.ReadHeader(chainDB, headHash)
+ if err != nil {
+ s.logger.Error("get header failed", "err", err)
+ os.Exit(1)
+ }
+
+ headNumber = header.Number
+ headStateRoot = header.StateRoot
+ }
+
+ // If the chain was rewound past the snapshot persistent layer (causing a recovery
+ // block number to be persisted to disk), check if we're still in recovery mode
+ // and in that case, don't invalidate the snapshot on a head mismatch.
+ // NOTE: this is impossible in a PoS chain; the only exception is a tampered database
+ if layer := rawdb.ReadSnapshotRecoveryNumber(trieDB); layer != nil && *layer >= headNumber {
+ s.logger.Warn("Enabling snapshot recovery", "chainhead", headNumber, "diskbase", *layer)
+
+ needRecover = true
+ }
+
+ snapCfg := snapshot.Config{
+ CacheSize: s.cacheConfig.SnapshotLimit,
+ Recovery: needRecover,
+ // background snapshots, if we use it in command line, it should be disabled
+ NoBuild: !s.config.EnableSnapshot,
+ AsyncBuild: !s.cacheConfig.SnapshotWait,
+ }
+
+ snaps, err := snapshot.New(snapCfg, trieDB, s.snpTrieDB, headStateRoot, logger, s.serverMetrics.snapshot)
+ if err != nil {
+ return err
+ }
+
+ // server snaps for rebuild and dump
+ s.snaps = snaps
+
+ // update executor's snaps
+ if s.executor != nil {
+ s.executor.SetSnaps(snaps)
+ }
+
+ return nil
+}
+
+func (s *Server) setupExecutor() error {
+ logger := s.logger.Named("executor")
+ executor := state.NewExecutor(
+ s.config.Chain.Params,
+ logger,
+ s.state,
+ )
+ // other properties
+ executor.SetRuntime(precompiled.NewPrecompiled())
+ executor.SetRuntime(evm.NewEVM())
+ executor.GetHash = s.GetHashHelper
+
+ // compute the genesis root state
+ // TODO: weird to commit every restart
+ genesisRoot, err := executor.WriteGenesis(s.config.Chain.Genesis.Alloc)
+ if err != nil {
+ return err
+ }
+
+ // set executor
+ s.executor = executor
+
+ // update genesis state root
+ s.config.Chain.Genesis.StateRoot = genesisRoot
+
+ return nil
+}
+
+func (s *Server) setupStateDB() error {
+ var (
+ config = s.config
+ logger = s.logger
+ levelOptions = config.LeveldbOptions
+ )
+
+ db, err := leveldb.New(
+ filepath.Join(config.DataDir, StateDataDir),
+ leveldb.SetBloomKeyBits(levelOptions.BloomKeyBits),
+ // trie cache + blockchain cache = leveldbOptions.CacheSize
+ leveldb.SetCacheSize(levelOptions.CacheSize/2),
+ leveldb.SetCompactionTableSize(levelOptions.CompactionTableSize),
+ leveldb.SetCompactionTotalSize(levelOptions.CompactionTotalSize),
+ leveldb.SetHandles(levelOptions.Handles),
+ leveldb.SetLogger(logger.Named("database").With("path", StateDataDir)),
+ leveldb.SetNoSync(levelOptions.NoSync),
+ )
+ if err != nil {
+ return err
+ }
+
+ s.trieDB = db
+
+ st := itrie.NewStateDB(db, logger, s.serverMetrics.trie)
+ s.state = st
+
+ return nil
+}
+
+func (s *Server) setupSnpTrieDB(trieDB kvdb.Database) {
+ cacheConfig := s.cacheConfig
+
+ s.snpTrieDB = trie.NewDatabaseWithConfig(
+ trieDB,
+ &trie.Config{
+ Cache: cacheConfig.TrieCleanLimit,
+ Journal: cacheConfig.TrieCleanJournal,
+ },
+ s.logger.Named("snapshotTrieDB"),
+ )
+}
+
+func (s *Server) setupBlockchainDB() error {
+ var (
+ config = s.config
+ leveldbOptions = config.LeveldbOptions
+ logger = s.logger
+ )
+
+ db, err := leveldb.New(
+ filepath.Join(config.DataDir, BlockchainDataDir),
+ leveldb.SetBloomKeyBits(leveldbOptions.BloomKeyBits),
+ // trie cache + blockchain cache = leveldbOptions.CacheSize
+ leveldb.SetCacheSize(leveldbOptions.CacheSize/2),
+ leveldb.SetCompactionTableSize(leveldbOptions.CompactionTableSize),
+ leveldb.SetCompactionTotalSize(leveldbOptions.CompactionTotalSize),
+ leveldb.SetHandles(leveldbOptions.Handles),
+ leveldb.SetLogger(logger.Named("database").With("path", BlockchainDataDir)),
+ leveldb.SetNoSync(leveldbOptions.NoSync),
+ )
+ if err != nil {
+ return err
+ }
+
+ s.chainDB = db
+
+ return nil
+}
+
+// setupNetwork sets up a libp2p network for the server.
+func (s *Server) setupNetwork() error {
+ config := s.config
+ netConfig := config.Network
+
+ // more detail configuration
+ netConfig.Chain = s.config.Chain
+ netConfig.DataDir = filepath.Join(s.config.DataDir, "libp2p")
+ netConfig.SecretsManager = s.secretsManager
+ netConfig.Metrics = s.serverMetrics.network
+
+ network, err := network.NewServer(s.logger, netConfig)
+ if err != nil {
+ return err
+ }
+
+ s.network = network
+
+ return nil
+}
+
+// setupMetris sets up metrics and the metric server.
+//
+// Must be done before the other components.
+func (s *Server) setupMetris() {
+ var (
+ config = s.config
+ namespace = "dogechain"
+ chainName = config.Chain.Name
+ )
+
+ if config.Telemetry.PrometheusAddr == nil {
+ s.serverMetrics = metricProvider(namespace, chainName, false, false)
+
+ return
+ }
- return m, nil
+ s.serverMetrics = metricProvider(namespace, chainName, true, config.Telemetry.EnableIOMetrics)
+ s.prometheusServer = s.startPrometheusServer(config.Telemetry.PrometheusAddr)
}
func (s *Server) restoreChain() error {
@@ -362,12 +584,13 @@ func (s *Server) restoreChain() error {
}
type txpoolHub struct {
+ snaps *snapshot.Tree
state state.State
*blockchain.Blockchain
}
func (t *txpoolHub) GetNonce(root types.Hash, addr types.Address) uint64 {
- account, err := getAccountImpl(t.state, root, addr)
+ account, err := getCommittedAccount(t.snaps, t.state, root, addr)
if err != nil {
return 0
}
@@ -376,13 +599,8 @@ func (t *txpoolHub) GetNonce(root types.Hash, addr types.Address) uint64 {
}
func (t *txpoolHub) GetBalance(root types.Hash, addr types.Address) (*big.Int, error) {
- account, err := getAccountImpl(t.state, root, addr)
+ account, err := getCommittedAccount(t.snaps, t.state, root, addr)
if err != nil {
- if errors.Is(err, jsonrpc.ErrStateNotFound) {
- // not exists, stop error propagation
- return big.NewInt(0), nil
- }
-
return big.NewInt(0), err
}
@@ -482,6 +700,7 @@ func (s *Server) setupConsensus() error {
}
s.consensus = consensus
+ s.blockchain.SetConsensus(consensus)
return nil
}
@@ -491,6 +710,7 @@ func (s *Server) setupConsensus() error {
// setupJSONRCP sets up the JSONRPC server, using the set configuration
func (s *Server) setupJSONRPC() error {
hub := NewJSONRPCStore(
+ s.snaps,
s.state,
s.blockchain,
s.restoreProgression,
@@ -538,6 +758,7 @@ func (s *Server) setupGraphQL() error {
}
hub := NewJSONRPCStore(
+ s.snaps,
s.state,
s.blockchain,
s.restoreProgression,
@@ -589,7 +810,7 @@ func (s *Server) setupGRPC() error {
// Chain returns the chain object of the client
func (s *Server) Chain() *chain.Chain {
- return s.chain
+ return s.config.Chain
}
// JoinPeer attempts to add a new peer to the networking server
@@ -614,33 +835,70 @@ func (s *Server) Close() {
// Close the consensus layer
if err := s.consensus.Close(); err != nil {
- s.logger.Error("failed to close consensus", "err", err.Error())
+ s.logger.Error("failed to close consensus", "err", err)
}
- s.logger.Info("close txpool")
-
// close the txpool's main loop
- s.txpool.Close()
-
- s.logger.Info("close network layer")
+ if s.txpool != nil {
+ s.logger.Info("close txpool")
+ s.txpool.Close()
+ }
// Close the networking layer
- if err := s.network.Close(); err != nil {
- s.logger.Error("failed to close networking", "err", err.Error())
+ if s.network != nil {
+ s.logger.Info("close network layer")
+
+ if err := s.network.Close(); err != nil {
+ s.logger.Error("failed to close networking", "err", err)
+ }
+ }
+
+ // Journal state snapshot and flush caches
+ s.journalSnapshots()
+
+ if s.blockchain != nil {
+ s.logger.Info("close blockchain")
+
+ // Close the blockchain layer
+ if err := s.blockchain.Close(); err != nil {
+ s.logger.Error("failed to close blockchain", "err", err)
+ }
}
- s.logger.Info("close state storage")
+ // Close jsonrpc server
+ if s.jsonrpcServer != nil {
+ s.logger.Info("close rpc server")
+
+ if err := s.jsonrpcServer.Close(); err != nil {
+ s.logger.Error("failed to close jsonrpc server", "err", err)
+ }
+ }
+
+ // Close graphql server
+ if s.graphqlServer != nil {
+ s.logger.Info("close graphql server")
+
+ if err := s.graphqlServer.Close(); err != nil {
+ s.logger.Error("failed to close graphql server", "err", err)
+ }
+ }
// Close the state storage
- if err := s.stateStorage.Close(); err != nil {
- s.logger.Error("failed to close storage for trie", "err", err.Error())
+ if s.trieDB != nil {
+ s.logger.Info("close state storage")
+
+ if err := s.trieDB.Close(); err != nil {
+ s.logger.Error("failed to close storage for trie", "err", err)
+ }
}
- s.logger.Info("close blockchain storage")
+ if s.chainDB != nil {
+ s.logger.Info("close blockchain storage")
- // Close the blockchain layer
- if err := s.blockchain.Close(); err != nil {
- s.logger.Error("failed to close blockchain", "err", err.Error())
+ // Close the blockchain storage
+ if err := s.chainDB.Close(); err != nil {
+ s.logger.Error("failed to close storage for blockchain", "err", err)
+ }
}
if s.prometheusServer != nil {
@@ -650,26 +908,64 @@ func (s *Server) Close() {
}
}
-// Entry is a backend configuration entry
-type Entry struct {
- Enabled bool
- Config map[string]interface{}
-}
+func (s *Server) journalSnapshots() {
+ // Snapshot base layer root, used in journalling when restarted.
+ var snapBase types.Hash
+ // Ensure that the entirety of the state snapshot is journalled to disk.
+ if s.snaps != nil {
+ s.logger.Info("Journal state snapshots to disk")
+
+ var err error
+ if snapBase, err = s.snaps.Journal(s.blockchain.Header().StateRoot); err != nil {
+ s.logger.Error("failed to journal state snapshot", "err", err)
+ }
+ }
-// SetupDataDir sets up the dogechain data directory and sub-folders
-func SetupDataDir(dataDir string, paths []string) error {
- if err := createDir(dataDir); err != nil {
- return fmt.Errorf("failed to create data dir: (%s): %w", dataDir, err)
+ // Ensure the state of a recent block is also stored to disk before exiting.
+ // We're writing different states to catch different restart scenarios:
+ // - HEAD: So we don't need to reprocess any blocks in the general case
+ // - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle
+ // - HEAD-127: So we have a hard limit on the number of blocks reexecuted
+ //
+ // Disable it in "archive" mode only
+ // if !s.cacheConfig.TrieDirtyDisabled
+ triedb := s.snpTrieDB
+
+ for _, offset := range []uint64{0, 1, TriesInMemory - 1} {
+ if number := s.blockchain.Header().Number; number > offset {
+ recent, ok := s.blockchain.GetBlockByNumber(number-offset, true)
+ if !ok {
+ s.logger.Error("block not exists", "number", number-offset)
+
+ continue
+ }
+
+ s.logger.Info("Writing cached state to disk",
+ "block", recent.Number(),
+ "hash", recent.Hash(),
+ "root", recent.Header.StateRoot,
+ )
+
+ if err := triedb.Commit(recent.Header.StateRoot, true, nil); err != nil {
+ s.logger.Error("Failed to commit recent state trie", "err", err)
+ }
+ }
}
- for _, path := range paths {
- path := filepath.Join(dataDir, path)
- if err := createDir(path); err != nil {
- return fmt.Errorf("failed to create path: (%s): %w", path, err)
+ // Update snapshot state root
+ if snapBase != types.ZeroHash {
+ s.logger.Info("Writing snapshot state to disk", "root", snapBase)
+
+ if err := triedb.Commit(snapBase, true, nil); err != nil {
+ s.logger.Error("Failed to commit recent state trie", "err", err)
}
}
- return nil
+ // Ensure all live cached entries be saved into disk, so that we can skip
+ // cache warmup when node restarts.
+ if s.cacheConfig.TrieCleanJournal != "" {
+ s.snpTrieDB.SaveCache(s.cacheConfig.TrieCleanJournal)
+ }
}
func (s *Server) startPrometheusServer(listenAddr *net.TCPAddr) *http.Server {
@@ -697,35 +993,92 @@ func (s *Server) startPrometheusServer(listenAddr *net.TCPAddr) *http.Server {
// helper functions
-// createDir creates a file system directory if it doesn't exist
-func createDir(path string) error {
- _, err := os.Stat(path)
- if err != nil && !os.IsNotExist(err) {
- return err
+func addressHash(addr types.Address) types.Hash {
+ return crypto.Keccak256Hash(addr.Bytes())
+}
+
+// getCommittedAccount is used for fetching (cached) account state from both TxPool and JSON-RPC
+//
+// a non-nil error means something unexpected happened
+func getCommittedAccount(
+ snaps *snapshot.Tree,
+ state state.State,
+ stateRoot types.Hash,
+ addr types.Address,
+) (*stypes.Account, error) {
+ if snaps != nil { // avoid crash
+ if snap := snaps.Snapshot(stateRoot); snap != nil { // snap found
+ account, err := snap.Account(addressHash(addr))
+ if account != nil { // account found
+ return account, nil
+ } else {
+ // might not be covered yet
+ hclog.Default().Debug("get account from snapshot failed", "address", addr, "err", err)
+ }
+ }
+ }
+
+ // If the snapshot is unavailable or reading from it fails, load from the database.
+ db, err := state.NewSnapshotAt(stateRoot)
+ if err != nil {
+ return nil, err
}
- if os.IsNotExist(err) {
- if err := os.MkdirAll(path, os.ModePerm); err != nil {
- return err
+ account, err := db.GetAccount(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ if account == nil {
+ // create an initialized account for querying
+ account = &stypes.Account{
+ Balance: new(big.Int),
}
}
- return nil
+ return account, nil
}
-// getAccountImpl is used for fetching account state from both TxPool and JSON-RPC
-func getAccountImpl(state state.State, root types.Hash, addr types.Address) (*state.Account, error) {
- snap, err := state.NewSnapshotAt(root)
+// getCommittedStorage is used for fetching (cached) storage from JSON-RPC
+func getCommittedStorage(
+ snaps *snapshot.Tree,
+ stateDB state.State,
+ stateRoot types.Hash,
+ addr types.Address,
+ slot types.Hash,
+) (types.Hash, error) {
+ if snaps != nil { // query cached snapshot
+ if snap := snaps.Snapshot(stateRoot); snap != nil {
+ // query it from cached snapshot
+ enc, err := snap.Storage(crypto.Keccak256Hash(addr.Bytes()), crypto.Keccak256Hash(slot.Bytes()))
+ if err == nil { // found
+ return utils.StorageBytesToHash(enc)
+ } else {
+ // print out error
+ hclog.Default().Debug("failed to get storage", "address", addr, "slot", slot, "err", err)
+ }
+ }
+ }
+
+ // get account at state root, so as to get the correct storage root
+ account, err := getCommittedAccount(snaps, stateDB, stateRoot, addr)
if err != nil {
- return nil, fmt.Errorf("unable to get snapshot for root '%s': %w", root, err)
+ // something bad happened
+ return types.Hash{}, err
}
- account, err := snap.GetAccount(addr)
+ // fast returns
+ if account.StorageRoot == types.ZeroHash ||
+ account.StorageRoot == types.EmptyRootHash {
+ return types.Hash{}, nil
+ }
+
+ // make a snapshot at root
+ db, err := stateDB.NewSnapshotAt(stateRoot)
if err != nil {
- return nil, err
- } else if account == nil {
- return nil, jsonrpc.ErrStateNotFound
+ return types.Hash{}, err
}
- return account, nil
+ // If the snapshot is unavailable or reading from it fails, load from the database.
+ return db.GetStorage(addr, account.StorageRoot, slot)
}
diff --git a/server/server_metrics.go b/server/server_metrics.go
index 0f77db91c5..138c1e9ea6 100644
--- a/server/server_metrics.go
+++ b/server/server_metrics.go
@@ -8,6 +8,7 @@ import (
"github.com/dogechain-lab/dogechain/txpool"
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
)
// serverMetrics holds the metric instances of all sub systems
@@ -17,8 +18,9 @@ type serverMetrics struct {
network *network.Metrics
txpool *txpool.Metrics
jsonrpc *jsonrpc.Metrics
- jsonrpcStore *JSONRPCStoreMetrics
+ jsonrpcStore *jsonrpcStoreMetrics
trie itrie.Metrics
+ snapshot *snapshot.Metrics
}
// metricProvider serverMetric instance for the given ChainID and nameSpace
@@ -32,6 +34,7 @@ func metricProvider(nameSpace string, chainID string, metricsRequired bool, trac
jsonrpc: jsonrpc.GetPrometheusMetrics(nameSpace, "chain_id", chainID),
jsonrpcStore: NewJSONRPCStoreMetrics(nameSpace, "chain_id", chainID),
trie: itrie.GetPrometheusMetrics(nameSpace, trackingIOTimer, "chain_id", chainID),
+ snapshot: snapshot.GetPrometheusMetrics(nameSpace, "chain_id", chainID),
}
}
@@ -43,5 +46,6 @@ func metricProvider(nameSpace string, chainID string, metricsRequired bool, trac
jsonrpc: jsonrpc.NilMetrics(),
jsonrpcStore: JSONRPCStoreNilMetrics(),
trie: itrie.NilMetrics(),
+ snapshot: snapshot.NilMetrics(),
}
}
diff --git a/server/system_service.go b/server/system_service.go
index 1be019ea37..401ff65641 100644
--- a/server/system_service.go
+++ b/server/system_service.go
@@ -32,7 +32,7 @@ func (s *systemService) GetStatus(ctx context.Context, req *empty.Empty) (*proto
header := s.server.blockchain.Header()
status := &proto.ServerStatus{
- Network: int64(s.server.chain.Params.ChainID),
+ Network: int64(s.server.config.Chain.Params.ChainID),
Current: &proto.ServerStatus_Block{
Number: int64(header.Number),
Hash: header.Hash.String(),
diff --git a/state/executor.go b/state/executor.go
index 67d662efb8..ffc5944c74 100644
--- a/state/executor.go
+++ b/state/executor.go
@@ -14,6 +14,8 @@ import (
"github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/state/runtime"
"github.com/dogechain-lab/dogechain/state/runtime/evm"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
)
@@ -25,8 +27,6 @@ const (
TxGasContractCreation uint64 = 53000 // Per transaction that creates a contract
)
-var emptyCodeHashTwo = types.BytesToHash(crypto.Keccak256(nil))
-
// GetHashByNumber returns the hash function of a block number
type GetHashByNumber = func(i uint64) types.Hash
@@ -34,26 +34,42 @@ type GetHashByNumberHelper = func(*types.Header) GetHashByNumber
// Executor is the main entity
type Executor struct {
- logger hclog.Logger
- config *chain.Params
+ config *chain.Params
+ logger hclog.Logger
+
runtimes []runtime.Runtime
state State
GetHash GetHashByNumberHelper
stopped uint32 // atomic flag for stopping
+ // world state snapshots
+ snaps *snapshot.Tree
+
+ // post hook for testing
PostHook func(txn *Transition)
}
// NewExecutor creates a new executor
-func NewExecutor(config *chain.Params, s State, logger hclog.Logger) *Executor {
+func NewExecutor(
+ config *chain.Params,
+ logger hclog.Logger,
+ s State,
+) *Executor {
return &Executor{
- logger: logger,
config: config,
+ logger: logger,
runtimes: []runtime.Runtime{},
state: s,
}
}
+// SetSnaps sets snapshots
+//
+// it should be enable only when snapshots is ready and validated
+func (e *Executor) SetSnaps(snaps *snapshot.Tree) {
+ e.snaps = snaps
+}
+
func (e *Executor) WriteGenesis(alloc map[types.Address]*chain.GenesisAccount) (types.Hash, error) {
snap := e.state.NewSnapshot()
txn := NewTxn(snap)
@@ -166,16 +182,6 @@ func (e *Executor) Stop() {
atomic.StoreUint32(&e.stopped, 1)
}
-// StateAt returns snapshot at given root
-func (e *Executor) State() State {
- return e.state
-}
-
-// StateAt returns snapshot at given root
-func (e *Executor) StateAt(root types.Hash) (Snapshot, error) {
- return e.state.NewSnapshotAt(root)
-}
-
// GetForksInTime returns the active forks at the given block height
func (e *Executor) GetForksInTime(blockNumber uint64) chain.ForksInTime {
return e.config.Forks.At(blockNumber)
@@ -223,6 +229,9 @@ func (e *Executor) BeginTxn(
evmLogger: runtime.NewDummyLogger(),
}
+ // Set the snapshots here, then no external setup is required
+ txn.SetSnapsRoot(e.snaps, parentRoot)
+
return txn, nil
}
@@ -249,6 +258,21 @@ type Transition struct {
// then we wouldn't have to judge any tracing flag
evmLogger runtime.EVMLogger
needDebug bool
+
+ // snaps caches world state snapshots to reduce long reads from database
+ snaps *snapshot.Tree // snapshot tree
+ snap snapshot.Snapshot // current world state
+}
+
+// SetSnapsRoot sets snapshots and current snapshot at root
+func (t *Transition) SetSnapsRoot(snaps *snapshot.Tree, root types.Hash) {
+ t.snaps = snaps
+
+ if snaps != nil {
+ if t.snap = snaps.Snapshot(root); t.snap != nil {
+ t.txn.SetSnap(t.snap)
+ }
+ }
}
// SetEVMLogger sets a non nil tracer to it
@@ -354,8 +378,8 @@ func (t *Transition) Write(txn *types.Transaction) error {
GasUsed: result.GasUsed,
}
- // Byzantium is always on now, otherwise it is not EVM-conpatable.
-
+ // Byzantium is always on now, otherwise it is not EVM-compatible.
+ //
// The suicided accounts are set as deleted for the next iteration
t.txn.CleanDeleteObjects(true)
@@ -411,7 +435,7 @@ func (t *Transition) handleBridgeLogs(msg *types.Transaction, logs []*types.Log)
}
// the total one is the real amount of Withdrawn event
- realAmount := big.NewInt(0).Add(parsedLog.Amount, parsedLog.Fee)
+ realAmount := new(big.Int).Add(parsedLog.Amount, parsedLog.Fee)
if err := t.txn.SubBalance(parsedLog.Contract, realAmount); err != nil {
return err
@@ -435,6 +459,80 @@ func (t *Transition) handleBridgeLogs(msg *types.Transaction, logs []*types.Log)
return nil
}
+func (t *Transition) UpdateSnapshot(root types.Hash, objs []*stypes.Object) {
+ if t.snap == nil {
+ return
+ }
+
+ defer func() {
+ // clear current snap since the state is useless
+ t.snap = nil
+ t.txn.CleanSnap()
+ }()
+
+ // Only update if there's a state transition (skip empty blocks)
+ parent := t.snap.Root()
+ if parent == root {
+ return
+ }
+
+ // get snap objects from txn
+ snapDestructs, snapAccounts, snapStorage := t.txn.GetSnapObjects()
+
+ // update all snap account state root
+ for _, obj := range objs {
+ addrHash := crypto.Keccak256Hash(obj.Address.Bytes())
+
+ if obj.Deleted {
+ // delete account
+ snapDestructs[addrHash] = struct{}{}
+
+ delete(snapAccounts, addrHash)
+ delete(snapStorage, addrHash)
+
+ continue
+ }
+
+ // update snap layer account
+ snapAccounts[addrHash] =
+ snapshot.SlimAccountRLP(
+ obj.Nonce,
+ obj.Balance,
+ obj.Root,
+ obj.CodeHash.Bytes(),
+ )
+ }
+
+ // update snapshot tree
+ if err := t.snaps.Update(
+ root,
+ parent,
+ snapDestructs,
+ snapAccounts,
+ snapStorage,
+ t.logger,
+ ); err != nil {
+ t.logger.Warn(
+ "Failed to update snapshot tree",
+ "from", parent,
+ "to", root,
+ "err", err,
+ )
+ }
+ // Keep 128 diff layers in the memory, persistent layer is 129th.
+ // - head layer is paired with HEAD state
+ // - head-1 layer is paired with HEAD-1 state
+ // - head-127 layer(bottom-most diff layer) is paired with HEAD-127 state
+ if err := t.snaps.Cap(root, 128); err != nil {
+ t.logger.Warn(
+ "Failed to cap snapshot tree",
+ "root", root,
+ "layers", 128,
+ "err", err,
+ )
+ }
+}
+
// Commit commits the final result
func (t *Transition) Commit() (Snapshot, types.Hash, error) {
objs := t.txn.Commit(t.config.EIP155)
@@ -444,6 +542,9 @@ func (t *Transition) Commit() (Snapshot, types.Hash, error) {
return nil, types.Hash{}, err
}
+ // If snapshotting is enabled, update the snapshot tree after committed.
+ t.UpdateSnapshot(types.BytesToHash(root), objs)
+
return s2, types.BytesToHash(root), nil
}
@@ -468,10 +569,6 @@ func (t *Transition) addGasPool(amount uint64) {
t.gasPool += amount
}
-func (t *Transition) SetTxn(txn *Txn) {
- t.txn = txn
-}
-
func (t *Transition) Txn() *Txn {
return t.txn
}
@@ -514,8 +611,20 @@ func (t *Transition) subGasLimitPrice(msg *types.Transaction) error {
return nil
}
+func (t Transition) nonceOverflowCheck(addr types.Address) (uint64, error) {
+ nonce := t.txn.GetNonce(addr)
+ if nonce+1 < nonce {
+ return 0, ErrNonceUintOverflow
+ }
+
+ return nonce, nil
+}
+
func (t *Transition) nonceCheck(msg *types.Transaction) error {
- nonce := t.txn.GetNonce(msg.From)
+ nonce, err := t.nonceOverflowCheck(msg.From)
+ if err != nil {
+ return err
+ }
if msg.Nonce < nonce {
return NewNonceTooLowError(fmt.Errorf("%w, actual: %d, wanted: %d", ErrNonceIncorrect, msg.Nonce, nonce), nonce)
@@ -538,6 +647,7 @@ var (
ErrNotEnoughFunds = errors.New("not enough funds for transfer with given value")
ErrAllGasUsed = errors.New("all gas used")
ErrExecutionStop = errors.New("execution stop")
+ ErrNonceUintOverflow = errors.New("nonce uint64 overflow")
)
type TransitionApplicationError struct {
@@ -814,7 +924,7 @@ func (t *Transition) hasCodeOrNonce(addr types.Address) bool {
codeHash := t.txn.GetCodeHash(addr)
- if codeHash != emptyCodeHashTwo && codeHash != emptyHash {
+ if codeHash != types.EmptyCodeHash && codeHash != emptyHash {
return true
}
@@ -831,6 +941,13 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim
}
}
+ if _, err := t.nonceOverflowCheck(c.Caller); err != nil {
+ return &runtime.ExecutionResult{
+ GasLeft: gasLimit,
+ Err: ErrNonceUintOverflow,
+ }
+ }
+
// Increment the nonce of the caller
t.txn.IncrNonce(c.Caller)
@@ -848,7 +965,7 @@ func (t *Transition) applyCreate(c *runtime.Contract, host runtime.Host) *runtim
if t.config.EIP158 {
// Force the creation of the account
t.txn.CreateAccount(c.Address)
- t.txn.IncrNonce(c.Address)
+ t.txn.IncrNonce(c.Address) // the contract nonce is 1 by default
}
var result *runtime.ExecutionResult
diff --git a/state/immutable-trie/encoding.go b/state/immutable-trie/encoding.go
index f8c0d21b9f..189cfb9af1 100644
--- a/state/immutable-trie/encoding.go
+++ b/state/immutable-trie/encoding.go
@@ -47,11 +47,13 @@ func encodeCompact(hex []byte) []byte {
// (with terminator flag). Prefix flag is not removed.
func bytesToHexNibbles(bytes []byte) []byte {
nibbles := make([]byte, len(bytes)*2+1)
+ // bytes to hex num slice
for i, b := range bytes {
nibbles[i*2] = b / 16
nibbles[i*2+1] = b % 16
}
+ // prefix ending
nibbles[len(nibbles)-1] = 16
return nibbles
diff --git a/state/immutable-trie/snapshot.go b/state/immutable-trie/snapshot.go
index 1d7903069d..9a3e273c1f 100644
--- a/state/immutable-trie/snapshot.go
+++ b/state/immutable-trie/snapshot.go
@@ -1,18 +1,26 @@
package itrie
import (
- "bytes"
"fmt"
"github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/state"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/state/utils"
"github.com/dogechain-lab/dogechain/types"
"github.com/dogechain-lab/fastrlp"
)
type Snapshot struct {
- state StateDB
- trie *Trie
+ stateDB StateDB
+ trie *Trie
+}
+
+func newSnapshotImpl(stateDB StateDB, trie *Trie) *Snapshot {
+ return &Snapshot{
+ stateDB: stateDB,
+ trie: trie,
+ }
}
func (s *Snapshot) GetStorage(addr types.Address, root types.Hash, rawkey types.Hash) (types.Hash, error) {
@@ -22,9 +30,10 @@ func (s *Snapshot) GetStorage(addr types.Address, root types.Hash, rawkey types.
)
if root == types.EmptyRootHash {
- ss = s.state.NewSnapshot()
+ ss = s.stateDB.NewSnapshot()
} else {
- ss, err = s.state.NewSnapshotAt(root)
+ // a new Snapshot on target contract state root
+ ss, err = s.stateDB.NewSnapshotAt(root)
if err != nil {
return types.Hash{}, err
}
@@ -39,34 +48,20 @@ func (s *Snapshot) GetStorage(addr types.Address, root types.Hash, rawkey types.
// slot to hash
key := crypto.Keccak256(rawkey.Bytes())
- val, err := snapshot.trie.Get(key, s.state)
+ val, err := snapshot.trie.Get(key, s.stateDB)
if err != nil {
// something bad happen, should not continue
return types.Hash{}, err
- } else if len(val) == 0 {
- // not found
- return types.Hash{}, nil
- }
-
- p := &fastrlp.Parser{}
-
- v, err := p.Parse(val)
- if err != nil {
- return types.Hash{}, err
}
- res := []byte{}
- if res, err = v.GetBytes(res[:0]); err != nil {
- return types.Hash{}, err
- }
-
- return types.BytesToHash(res), nil
+ // not found should return empty hash
+ return utils.StorageBytesToHash(val)
}
-func (s *Snapshot) GetAccount(addr types.Address) (*state.Account, error) {
- key := crypto.Keccak256(addr.Bytes())
+func (s *Snapshot) GetAccount(addr types.Address) (*stypes.Account, error) {
+ key := addressHash(addr)
- data, err := s.trie.Get(key, s.state)
+ data, err := s.trie.Get(key, s.stateDB)
if err != nil {
return nil, err
} else if data == nil {
@@ -74,7 +69,7 @@ func (s *Snapshot) GetAccount(addr types.Address) (*state.Account, error) {
return nil, nil
}
- var account state.Account
+ var account stypes.Account
if err := account.UnmarshalRlp(data); err != nil {
return nil, err
}
@@ -83,23 +78,23 @@ func (s *Snapshot) GetAccount(addr types.Address) (*state.Account, error) {
}
func (s *Snapshot) GetCode(hash types.Hash) ([]byte, bool) {
- return s.state.GetCode(hash)
+ return s.stateDB.GetCode(hash)
}
-func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error) {
+func (s *Snapshot) Commit(objs []*stypes.Object) (state.Snapshot, []byte, error) {
var (
root []byte = nil
nTrie *Trie = nil
// metrics logger
- metrics = s.state.GetMetrics()
+ metrics = s.stateDB.GetMetrics()
insertCount = 0
deleteCount = 0
newSetCodeCount = 0
)
// Create an insertion batch for all the entries
- err := s.state.Transaction(func(st StateDBTransaction) error {
+ err := s.stateDB.Transaction(func(st StateDBTransaction) error {
defer st.Rollback()
tt := s.trie.Txn(st)
@@ -107,26 +102,25 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
arena := fastrlp.DefaultArenaPool.Get()
defer fastrlp.DefaultArenaPool.Put(arena)
- ar1 := fastrlp.DefaultArenaPool.Get()
- defer fastrlp.DefaultArenaPool.Put(ar1)
-
for _, obj := range objs {
if obj.Deleted {
- err := tt.Delete(hashit(obj.Address.Bytes()))
+ // address hash
+ err := tt.Delete(addressHash(obj.Address))
if err != nil {
return err
}
deleteCount++
} else {
- account := state.Account{
- Balance: obj.Balance,
- Nonce: obj.Nonce,
- CodeHash: obj.CodeHash.Bytes(),
- Root: obj.Root, // old root
+ account := stypes.Account{
+ Balance: obj.Balance,
+ Nonce: obj.Nonce,
+ CodeHash: obj.CodeHash.Bytes(),
+ StorageRoot: obj.Root, // old root
}
if len(obj.Storage) != 0 {
+ // last root
rootsnap, err := st.NewSnapshotAt(obj.Root)
// s.state.newTrieAt(obj.Root)
if err != nil {
@@ -136,10 +130,11 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
// tricky, but necessary here
loadSnap, _ := rootsnap.(*Snapshot)
// create a new Txn since we don't know whether there is any cache in it
- localTxn := loadSnap.trie.Txn(loadSnap.state)
+ localTxn := loadSnap.trie.Txn(loadSnap.stateDB)
for _, entry := range obj.Storage {
- k := hashit(entry.Key)
+ // slot hash
+ k := crypto.Keccak256(entry.Key)
if entry.Deleted {
err := localTxn.Delete(k)
if err != nil {
@@ -148,8 +143,7 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
deleteCount++
} else {
- vv := ar1.NewBytes(bytes.TrimLeft(entry.Val, "\x00"))
- err := localTxn.Insert(k, vv.MarshalTo(nil))
+ err := localTxn.Insert(k, entry.Val)
if err != nil {
return err
}
@@ -167,7 +161,10 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
// end observe account hash time
observe()
- account.Root = types.BytesToHash(accountStateRoot)
+ account.StorageRoot = types.BytesToHash(accountStateRoot)
+
+ // update object state root, so that we could use it later
+ obj.Root = account.StorageRoot
}
if obj.DirtyCode {
@@ -184,7 +181,7 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
vv := account.MarshalWith(arena)
data := vv.MarshalTo(nil)
- tt.Insert(hashit(obj.Address.Bytes()), data)
+ tt.Insert(addressHash(obj.Address), data)
insertCount++
arena.Reset()
@@ -219,5 +216,5 @@ func (s *Snapshot) Commit(objs []*state.Object) (state.Snapshot, []byte, error)
metrics.transactionNewAccountObserve(newSetCodeCount)
}
- return &Snapshot{trie: nTrie, state: s.state}, root, err
+ return newSnapshotImpl(s.stateDB, nTrie), root, err
}
diff --git a/state/immutable-trie/statedb.go b/state/immutable-trie/statedb.go
index 020d1cb454..1788410043 100644
--- a/state/immutable-trie/statedb.go
+++ b/state/immutable-trie/statedb.go
@@ -6,6 +6,7 @@ import (
"sync"
"github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
@@ -13,9 +14,6 @@ import (
)
var (
- // codePrefix is the code prefix for leveldb
- codePrefix = []byte("code")
-
ErrStateTransactionIsCancel = errors.New("transaction is cancel")
)
@@ -90,6 +88,14 @@ func (db *stateDBImpl) Logger() hclog.Logger {
return db.logger
}
+func (db *stateDBImpl) Has(p []byte) (bool, error) {
+ if db.cached.Has(p) {
+ return true, nil
+ }
+
+ return db.storage.Has(p)
+}
+
func (db *stateDBImpl) Get(k []byte) ([]byte, bool, error) {
if enc := db.cached.Get(nil, k); enc != nil {
db.metrics.accountCacheHitInc()
@@ -122,8 +128,8 @@ func (db *stateDBImpl) Get(k []byte) ([]byte, bool, error) {
}
func (db *stateDBImpl) GetCode(hash types.Hash) ([]byte, bool) {
- perfix := append(codePrefix, hash.Bytes()...)
- if enc := db.codeCache.Get(nil, perfix); enc != nil {
+ key := rawdb.CodeKey(hash)
+ if enc := db.codeCache.Get(nil, key); enc != nil {
db.metrics.codeCacheHitInc()
return enc, true
@@ -134,7 +140,7 @@ func (db *stateDBImpl) GetCode(hash types.Hash) ([]byte, bool) {
// start observe disk read time
observe := db.metrics.codeDiskReadSecondsObserve()
- v, ok, err := db.storage.Get(perfix)
+ v, ok, err := db.storage.Get(key)
if err != nil {
db.logger.Error("failed to get code", "err", err)
}
@@ -146,7 +152,7 @@ func (db *stateDBImpl) GetCode(hash types.Hash) ([]byte, bool) {
// write-back cache
if err == nil && ok {
- db.cached.Set(perfix, v)
+ db.cached.Set(key, v)
}
if !ok {
@@ -157,7 +163,7 @@ func (db *stateDBImpl) GetCode(hash types.Hash) ([]byte, bool) {
}
func (db *stateDBImpl) NewSnapshot() state.Snapshot {
- return &Snapshot{state: db, trie: db.newTrie()}
+ return newSnapshotImpl(db, db.newTrie())
}
func (db *stateDBImpl) NewSnapshotAt(root types.Hash) (state.Snapshot, error) {
@@ -166,7 +172,7 @@ func (db *stateDBImpl) NewSnapshotAt(root types.Hash) (state.Snapshot, error) {
return nil, err
}
- return &Snapshot{state: db, trie: t}, nil
+ return newSnapshotImpl(db, t), nil
}
var stateTxnPool = sync.Pool{
diff --git a/state/immutable-trie/statedb_test.go b/state/immutable-trie/statedb_test.go
index c123cc61f2..b9d863511d 100644
--- a/state/immutable-trie/statedb_test.go
+++ b/state/immutable-trie/statedb_test.go
@@ -3,6 +3,7 @@ package itrie
import (
"testing"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
"github.com/dogechain-lab/dogechain/state"
"github.com/hashicorp/go-hclog"
)
@@ -12,8 +13,7 @@ func TestState(t *testing.T) {
}
func buildPreState(pre state.PreStates) state.Snapshot {
- storage := NewMemoryStorage()
- st := NewStateDB(storage, hclog.NewNullLogger(), nil)
+ st := NewStateDB(memorydb.New(), hclog.NewNullLogger(), nil)
snap := st.NewSnapshot()
return snap
diff --git a/state/immutable-trie/statedb_transaction.go b/state/immutable-trie/statedb_transaction.go
index 1ece1530e1..2cb7a7f912 100644
--- a/state/immutable-trie/statedb_transaction.go
+++ b/state/immutable-trie/statedb_transaction.go
@@ -6,6 +6,7 @@ import (
"fmt"
"sync"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/types"
"go.uber.org/atomic"
@@ -47,7 +48,7 @@ func (pair *txnPair) Reset() {
type stateDBTxn struct {
db map[txnKey]*txnPair
- lock sync.Mutex
+ lock sync.Mutex // for protecting map
stateDB StateDB
storage Storage
@@ -56,8 +57,7 @@ type stateDBTxn struct {
}
func (tx *stateDBTxn) Set(k []byte, v []byte) error {
- tx.lock.Lock()
- defer tx.lock.Unlock()
+ key := byteKeyToTxnKey(k)
pair, ok := txnPairPool.Get().(*txnPair)
if !ok {
@@ -67,31 +67,71 @@ func (tx *stateDBTxn) Set(k []byte, v []byte) error {
pair.key = append(pair.key, k...)
pair.value = append(pair.value, v...)
- tx.db[txnKey(hex.EncodeToString(k))] = pair
+ tx.lock.Lock()
+ defer tx.lock.Unlock()
+
+ tx.db[key] = pair
return nil
}
-func (tx *stateDBTxn) Get(k []byte) ([]byte, bool, error) {
+func (tx *stateDBTxn) Delete(k []byte) error {
+ key := byteKeyToTxnKey(k)
+
tx.lock.Lock()
defer tx.lock.Unlock()
- v, ok := tx.db[txnKey(hex.EncodeToString(k))]
- if !ok {
- return tx.stateDB.Get(k)
+ delete(tx.db, key)
+
+ return nil
+}
+
+func byteKeyToTxnKey(k []byte) txnKey {
+ return txnKey(hex.EncodeToString(k))
+}
+
+func (tx *stateDBTxn) Has(k []byte) (bool, error) {
+ key := byteKeyToTxnKey(k)
+
+ tx.lock.Lock()
+
+ if _, ok := tx.db[key]; ok {
+ tx.lock.Unlock()
+
+ return true, nil
}
- bufValue := make([]byte, len(v.value))
- copy(bufValue[:], v.value[:])
+ tx.lock.Unlock()
- return bufValue, true, nil
+ return tx.stateDB.Has(k)
}
-func (tx *stateDBTxn) SetCode(hash types.Hash, v []byte) error {
+func (tx *stateDBTxn) Get(k []byte) ([]byte, bool, error) {
+ key := byteKeyToTxnKey(k)
+
tx.lock.Lock()
- defer tx.lock.Unlock()
- perfix := append(codePrefix, hash.Bytes()...)
+ v, ok := tx.db[key]
+ if ok {
+ // copy value
+ bufValue := make([]byte, len(v.value))
+ copy(bufValue[:], v.value[:])
+
+ // unlock
+ tx.lock.Unlock()
+
+ return bufValue, true, nil
+ }
+
+ tx.lock.Unlock()
+
+ return tx.stateDB.Get(k)
+}
+
+func (tx *stateDBTxn) SetCode(hash types.Hash, v []byte) error {
+ // active code key is different from account key (hash)
+ key := rawdb.CodeKey(hash)
+ keyStr := byteKeyToTxnKey(key)
pair, ok := txnPairPool.Get().(*txnPair)
if !ok {
@@ -99,31 +139,37 @@ func (tx *stateDBTxn) SetCode(hash types.Hash, v []byte) error {
}
// overwrite them
- pair.key = append(pair.key[:0], perfix...)
+ pair.key = append(pair.key[:0], key...)
pair.value = append(pair.value[:0], v...)
pair.isCode = true
- tx.db[txnKey(hex.EncodeToString(perfix))] = pair
+ tx.lock.Lock()
+ defer tx.lock.Unlock()
+
+ tx.db[keyStr] = pair
return nil
}
func (tx *stateDBTxn) GetCode(hash types.Hash) ([]byte, bool) {
+ key := byteKeyToTxnKey(rawdb.CodeKey(hash))
+
tx.lock.Lock()
- defer tx.lock.Unlock()
- perfix := append(codePrefix, hash.Bytes()...)
+ if v, ok := tx.db[key]; ok {
+		// deep copy
+ bufValue := make([]byte, len(v.value))
+ copy(bufValue[:], v.value[:])
- v, ok := tx.db[txnKey(hex.EncodeToString(perfix))]
- if !ok {
- return tx.stateDB.GetCode(hash)
+ // unlock
+ tx.lock.Unlock()
+
+ return bufValue, true
}
- // depth copy
- bufValue := make([]byte, len(v.value))
- copy(bufValue[:], v.value[:])
+ tx.lock.Unlock()
- return bufValue, true
+ return tx.stateDB.GetCode(hash)
}
func (tx *stateDBTxn) NewSnapshot() state.Snapshot {
@@ -148,7 +194,7 @@ func (tx *stateDBTxn) NewSnapshotAt(root types.Hash) (state.Snapshot, error) {
t := NewTrie()
t.root = n
- return &Snapshot{state: tx.stateDB, trie: t}, nil
+ return newSnapshotImpl(tx.stateDB, t), nil
}
func (tx *stateDBTxn) Commit() error {
@@ -157,10 +203,11 @@ func (tx *stateDBTxn) Commit() error {
}
tx.lock.Lock()
- defer tx.lock.Unlock()
// double check
if tx.cancel.Load() {
+ tx.lock.Unlock()
+
return ErrStateTransactionIsCancel
}
@@ -171,6 +218,8 @@ func (tx *stateDBTxn) Commit() error {
err := batch.Set(pair.key, pair.value)
if err != nil {
+ tx.lock.Unlock()
+
return err
}
@@ -179,20 +228,18 @@ func (tx *stateDBTxn) Commit() error {
}
}
- return batch.Commit()
+ tx.lock.Unlock()
+
+ return batch.Write()
}
// clear transaction data, set cancel flag
func (tx *stateDBTxn) Rollback() {
- tx.lock.Lock()
- defer tx.lock.Unlock()
-
- if tx.cancel.Load() {
+	// cancel by atomic swap value
+ if alreadyCancel := tx.cancel.Swap(true); alreadyCancel {
return
}
- tx.cancel.Store(true)
-
tx.clear()
}
@@ -200,6 +247,9 @@ func (tx *stateDBTxn) clear() {
tx.stateDB = nil
tx.storage = nil
+ tx.lock.Lock()
+ defer tx.lock.Unlock()
+
for tk := range tx.db {
pair := tx.db[tk]
pair.Reset()
diff --git a/state/immutable-trie/storage.go b/state/immutable-trie/storage.go
index 5d952e1767..cc902a7e63 100644
--- a/state/immutable-trie/storage.go
+++ b/state/immutable-trie/storage.go
@@ -3,132 +3,14 @@ package itrie
import (
"fmt"
- "github.com/dogechain-lab/dogechain/helper/hex"
"github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
"github.com/dogechain-lab/fastrlp"
)
-var parserPool fastrlp.ParserPool
-
-// Storage stores the trie
-type StorageReader interface {
- Get(k []byte) ([]byte, bool, error)
-}
-
-type StorageWriter interface {
- Set(k, v []byte) error
-}
-
-type Batch interface {
- StorageWriter
-
- Commit() error
-}
-
-// Storage stores the trie
-type Storage interface {
- StorageReader
- StorageWriter
-
- NewBatch() Batch
- Close() error
-}
-
-type kvStorageBatch struct {
- batch kvdb.KVBatch
-}
-
-func (kvBatch *kvStorageBatch) Set(k, v []byte) error {
- kvBatch.batch.Set(k, v)
-
- return nil
-}
-
-func (kvBatch *kvStorageBatch) Commit() error {
- return kvBatch.batch.Write()
-}
-
-// wrap generic kvdb storage to implement Storage interface
-type kvStorage struct {
- db kvdb.KVBatchStorage
-}
-
-func (kv *kvStorage) Get(k []byte) ([]byte, bool, error) {
- return kv.db.Get(k)
-}
-
-func (kv *kvStorage) Set(k, v []byte) error {
- return kv.db.Set(k, v)
-}
-
-func (kv *kvStorage) NewBatch() Batch {
- return &kvStorageBatch{
- batch: kv.db.Batch(),
- }
-}
-
-func (kv *kvStorage) Close() error {
- return kv.db.Close()
-}
-
-func NewLevelDBStorage(leveldbBuilder kvdb.LevelDBBuilder) (Storage, error) {
- db, err := leveldbBuilder.Build()
- if err != nil {
- return nil, err
- }
-
- return &kvStorage{db: db}, nil
-}
-
-type memStorage struct {
- db map[string][]byte
-}
-
-type memBatch struct {
- db *map[string][]byte
-}
-
-// NewMemoryStorage creates an inmemory trie storage
-func NewMemoryStorage() Storage {
- return &memStorage{db: map[string][]byte{}}
-}
-
-func (m *memStorage) Set(p []byte, v []byte) error {
- buf := make([]byte, len(v))
- copy(buf[:], v[:])
- m.db[hex.EncodeToHex(p)] = buf
-
- return nil
-}
-
-func (m *memStorage) Get(p []byte) ([]byte, bool, error) {
- v, ok := m.db[hex.EncodeToHex(p)]
- if !ok {
- return []byte{}, false, nil
- }
-
- return v, true, nil
-}
-
-func (m *memStorage) NewBatch() Batch {
- return &memBatch{db: &m.db}
-}
-
-func (m *memStorage) Close() error {
- return nil
-}
-
-func (m *memBatch) Set(p, v []byte) error {
- buf := make([]byte, len(v))
- copy(buf[:], v[:])
- (*m.db)[hex.EncodeToHex(p)] = buf
-
- return nil
-}
-
-func (m *memBatch) Commit() error {
- return nil
-}
+type StorageReader kvdb.KVReader
+type StorageWriter kvdb.KVWriter
+type Storage kvdb.KVBatchStorage
// GetNode retrieves a node from storage
func GetNode(root []byte, storage StorageReader) (Node, bool, error) {
@@ -139,10 +21,7 @@ func GetNode(root []byte, storage StorageReader) (Node, bool, error) {
// NOTE. We dont need to make copies of the bytes because the nodes
// take the reference from data itself which is a safe copy.
- p := parserPool.Get()
- defer parserPool.Put(p)
-
- v, err := p.Parse(data)
+ v, err := types.RlpUnmarshal(data)
if err != nil {
return nil, false, err
}
@@ -167,9 +46,8 @@ func decodeNode(v *fastrlp.Value) (Node, error) {
var err error
- // TODO remove this once 1.0.4 of ifshort is merged in golangci-lint
- ll := v.Elems() //nolint:ifshort
- if ll == 2 {
+ switch v.Elems() {
+ case 2:
key := v.Get(0)
if key.Type() != fastrlp.TypeBytes {
return nil, fmt.Errorf("short key expected to be bytes")
@@ -197,14 +75,16 @@ func decodeNode(v *fastrlp.Value) (Node, error) {
}
return nc, nil
- } else if ll == 17 {
+ case 17:
// full node
nc := nodePool.GetFullNode()
+
for i := 0; i < 16; i++ {
if v.Get(i).Type() == fastrlp.TypeBytes && len(v.Get(i).Raw()) == 0 {
// empty
continue
}
+
nc.children[i], err = decodeNode(v.Get(i))
if err != nil {
return nil, err
@@ -214,6 +94,7 @@ func decodeNode(v *fastrlp.Value) (Node, error) {
if v.Get(16).Type() != fastrlp.TypeBytes {
return nil, fmt.Errorf("full node value expected to be bytes")
}
+
if len(v.Get(16).Raw()) != 0 {
vv := nodePool.GetValueNode()
vv.buf = append(vv.buf[0:0], v.Get(16).Raw()...)
diff --git a/state/immutable-trie/trie.go b/state/immutable-trie/trie.go
index 0894e0efeb..a98ffacf03 100644
--- a/state/immutable-trie/trie.go
+++ b/state/immutable-trie/trie.go
@@ -20,8 +20,8 @@ func (t *Trie) Get(k []byte, reader StateDBReader) ([]byte, error) {
return txn.Lookup(k)
}
-func hashit(k []byte) []byte {
- return crypto.Keccak256(k)
+func addressHash(addr types.Address) []byte {
+ return crypto.Keccak256(addr.Bytes())
}
// Hash returns the root hash of the trie. It does not write to the
diff --git a/state/journal.go b/state/journal.go
new file mode 100644
index 0000000000..06cf37590a
--- /dev/null
+++ b/state/journal.go
@@ -0,0 +1,190 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// journalEntry is a modification entry in the state change journal that can be
+// reverted on demand.
+type journalEntry interface {
+ // revert undoes the changes introduced by this journal entry.
+ revert(*Txn)
+
+ // dirtied returns the Ethereum address modified by this journal entry.
+ dirtied() *types.Address
+}
+
+// journal contains the list of state modifications applied since the last state
+// commit. These are tracked to be able to be reverted in the case of an execution
+// exception or request for reversal.
+type journal struct {
+ entries []journalEntry // Current changes tracked by the journal
+ dirties map[types.Address]int // Dirty accounts and the number of changes
+}
+
+// newJournal creates a new initialized journal.
+func newJournal() *journal {
+ return &journal{
+ dirties: make(map[types.Address]int),
+ }
+}
+
+// append inserts a new modification entry to the end of the change journal.
+func (j *journal) append(entry journalEntry) {
+ j.entries = append(j.entries, entry)
+
+ if addr := entry.dirtied(); addr != nil {
+ j.dirties[*addr]++
+ }
+}
+
+// revert undoes a batch of journalled modifications along with any reverted
+// dirty handling too.
+func (j *journal) revert(txn *Txn, snapshot int) {
+ for i := len(j.entries) - 1; i >= snapshot; i-- {
+ // Undo the changes made by the operation
+ j.entries[i].revert(txn)
+
+ // Drop any dirty tracking induced by the change
+ if addr := j.entries[i].dirtied(); addr != nil {
+ if j.dirties[*addr]--; j.dirties[*addr] == 0 {
+ delete(j.dirties, *addr)
+ }
+ }
+ }
+
+ j.entries = j.entries[:snapshot]
+}
+
+// length returns the current number of entries in the journal.
+func (j *journal) length() int {
+ return len(j.entries)
+}
+
+type (
+ // Changes to the account trie.
+ resetObjectChange struct {
+ prev *stateObject
+ prevdestruct bool
+ }
+
+ suicideChange struct {
+ account *types.Address
+ prev bool // whether account had already suicided
+ prevbalance *big.Int
+ }
+
+ // Changes to individual accounts.
+ balanceChange struct {
+ account *types.Address
+ prev *big.Int
+ }
+
+ nonceChange struct {
+ account *types.Address
+ prev uint64
+ }
+
+ codeChange struct {
+ account *types.Address
+ prevcode, prevhash []byte
+ }
+)
+
+func (ch resetObjectChange) revert(t *Txn) {
+ if !ch.prevdestruct && t.snap != nil {
+ delete(t.snapDestructs, ch.prev.AddressHash())
+ }
+}
+
+func (ch resetObjectChange) dirtied() *types.Address {
+ return nil
+}
+
+func addressHash(addr *types.Address) types.Hash {
+ return crypto.Keccak256Hash(addr.Bytes())
+}
+
+func (ch suicideChange) revert(t *Txn) {
+ obj, _ := t.getStateObject(*ch.account)
+ if obj == nil {
+ delete(t.snapAccounts, addressHash(ch.account))
+ } else {
+ obj.suicide = ch.prev
+ // balance
+ obj.setBalance(ch.prevbalance)
+ // revert journaled account
+ t.updateSnapAccount(obj)
+ }
+}
+
+func (ch suicideChange) dirtied() *types.Address {
+ return ch.account
+}
+
+func (ch balanceChange) revert(t *Txn) {
+ obj, _ := t.getStateObject(*ch.account)
+ if obj == nil {
+ delete(t.snapAccounts, addressHash(ch.account))
+ } else {
+ // balance
+ obj.setBalance(ch.prev)
+ // revert journaled account
+ t.updateSnapAccount(obj)
+ }
+}
+
+func (ch balanceChange) dirtied() *types.Address {
+ return ch.account
+}
+
+func (ch nonceChange) revert(t *Txn) {
+ obj, _ := t.getStateObject(*ch.account)
+ if obj == nil {
+ delete(t.snapAccounts, addressHash(ch.account))
+ } else {
+ // nonce
+ obj.setNonce(ch.prev)
+ // revert journaled account
+ t.updateSnapAccount(obj)
+ }
+}
+
+func (ch nonceChange) dirtied() *types.Address {
+ return ch.account
+}
+
+func (ch codeChange) revert(t *Txn) {
+ obj, _ := t.getStateObject(*ch.account)
+ if obj == nil {
+ delete(t.snapAccounts, addressHash(ch.account))
+ } else {
+ // code
+ obj.setCode(types.BytesToHash(ch.prevhash), ch.prevcode)
+ // revert journaled account
+ t.updateSnapAccount(obj)
+ }
+}
+
+func (ch codeChange) dirtied() *types.Address {
+ return ch.account
+}
diff --git a/state/runtime/evm/state_test.go b/state/runtime/evm/state_test.go
index 6f3a873eee..ba6bd88f14 100644
--- a/state/runtime/evm/state_test.go
+++ b/state/runtime/evm/state_test.go
@@ -26,7 +26,7 @@ func (c *codeHelper) pop() {
}
func getState() (*state, func()) {
- c := statePool.Get().(*state) //nolint:forcetypeassert
+ c := statePool.Get().(*state)
return c, func() {
c.reset()
diff --git a/state/runtime/precompiled/precompiled.go b/state/runtime/precompiled/precompiled.go
index 24eaad7099..619b855917 100644
--- a/state/runtime/precompiled/precompiled.go
+++ b/state/runtime/precompiled/precompiled.go
@@ -126,7 +126,7 @@ func (p *Precompiled) Run(c *runtime.Contract, _ runtime.Host, config *chain.For
var zeroPadding = make([]byte, 64)
func (p *Precompiled) leftPad(buf []byte, n int) []byte {
- // TODO, avoid buffer allocation
+ // avoid buffer allocation
l := len(buf)
if l > n {
return buf
diff --git a/state/runtime/runtime.go b/state/runtime/runtime.go
index 994d5a6b0c..139babad30 100644
--- a/state/runtime/runtime.go
+++ b/state/runtime/runtime.go
@@ -58,8 +58,8 @@ func (s StorageStatus) String() string {
// Host is the execution host
type Host interface {
AccountExists(addr types.Address) bool
- GetStorage(addr types.Address, key types.Hash) (types.Hash, error)
- SetStorage(addr types.Address, key types.Hash, value types.Hash, config *chain.ForksInTime) StorageStatus
+ GetStorage(addr types.Address, slot types.Hash) (types.Hash, error)
+ SetStorage(addr types.Address, slot types.Hash, value types.Hash, config *chain.ForksInTime) StorageStatus
GetBalance(addr types.Address) *big.Int
GetCodeSize(addr types.Address) int
GetCodeHash(addr types.Address) types.Hash
diff --git a/state/snapshot/account.go b/state/snapshot/account.go
new file mode 100644
index 0000000000..be08f9f45b
--- /dev/null
+++ b/state/snapshot/account.go
@@ -0,0 +1,78 @@
+package snapshot
+
+import (
+ "bytes"
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Account is a modified version of a state.Account, where the root is replaced
+// with a byte slice. This format can be used to represent full-consensus format
+// or slim-snapshot format which replaces the empty root and code hash as nil
+// byte slice.
+type Account struct {
+ Nonce uint64
+ Balance *big.Int
+ Root []byte
+ CodeHash []byte
+}
+
+// SlimAccount converts a state.Account content into a slim snapshot account
+func SlimAccount(nonce uint64, balance *big.Int, root types.Hash, codehash []byte) Account {
+ slim := Account{
+ Nonce: nonce,
+ Balance: balance,
+ }
+
+ if root != types.EmptyRootHash {
+ slim.Root = root[:]
+ }
+
+ if !bytes.Equal(codehash, types.EmptyCodeHash.Bytes()) {
+ slim.CodeHash = codehash
+ }
+
+ return slim
+}
+
+// SlimAccountRLP converts a state.Account content into a slim snapshot
+// version RLP encoded.
+func SlimAccountRLP(nonce uint64, balance *big.Int, root types.Hash, codehash []byte) []byte {
+ data, err := rlp.EncodeToBytes(SlimAccount(nonce, balance, root, codehash))
+ if err != nil {
+ panic(err)
+ }
+
+ return data
+}
+
+// FullAccount decodes the data on the 'slim RLP' format and return
+// the consensus format account.
+func FullAccount(data []byte) (Account, error) {
+ var account Account
+ if err := rlp.DecodeBytes(data, &account); err != nil {
+ return Account{}, err
+ }
+
+ if len(account.Root) == 0 {
+ account.Root = types.EmptyRootHash.Bytes()
+ }
+
+ if len(account.CodeHash) == 0 {
+ account.CodeHash = types.EmptyCodeHash.Bytes()
+ }
+
+ return account, nil
+}
+
+// FullAccountRLP converts data on the 'slim RLP' format into the full RLP-format.
+func FullAccountRLP(data []byte) ([]byte, error) {
+ account, err := FullAccount(data)
+ if err != nil {
+ return nil, err
+ }
+
+ return rlp.EncodeToBytes(account)
+}
diff --git a/state/snapshot/context.go b/state/snapshot/context.go
new file mode 100644
index 0000000000..6ebe205edb
--- /dev/null
+++ b/state/snapshot/context.go
@@ -0,0 +1,311 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "math"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+const (
+ snapAccount = "account" // Identifier of account snapshot generation
+ snapStorage = "storage" // Identifier of storage snapshot generation
+)
+
+// generatorStats is a collection of statistics gathered by the snapshot generator
+// for logging purposes.
+type generatorStats struct {
+ origin uint64 // Origin prefix where generation started
+ start time.Time // Timestamp when generation started
+ accounts uint64 // Number of accounts indexed(generated or recovered)
+ slots uint64 // Number of storage slots indexed(generated or recovered)
+ dangling uint64 // Number of dangling storage slots
+ storage types.StorageSize // Total account and storage slot size(generation or recovery)
+ logger kvdb.Logger // logger
+ generateMetrics *Metrics // generate metrics
+}
+
+// Log creates an contextual log with the given message and the context pulled
+// from the internally maintained statistics.
+func (gs *generatorStats) Log(msg string, root types.Hash, marker []byte) {
+ var ctx []interface{}
+ if root != (types.Hash{}) {
+ ctx = append(ctx, []interface{}{"root", root}...)
+ }
+
+ // Figure out whether we're after or within an account
+ switch len(marker) {
+ case types.HashLength:
+ ctx = append(ctx, []interface{}{"at", types.BytesToHash(marker)}...)
+ case 2 * types.HashLength:
+ ctx = append(ctx, []interface{}{
+ "in", types.BytesToHash(marker[:types.HashLength]),
+ "at", types.BytesToHash(marker[types.HashLength:]),
+ }...)
+ }
+
+ // Add the usual measurements
+ ctx = append(ctx, []interface{}{
+ "accounts", gs.accounts,
+ "slots", gs.slots,
+ "storage", gs.storage,
+ "dangling", gs.dangling,
+ "elapsed", types.PrettyDuration(time.Since(gs.start)),
+ }...)
+
+ // Calculate the estimated indexing time based on current stats
+ if len(marker) > 0 {
+ if done := binary.BigEndian.Uint64(marker[:8]) - gs.origin; done > 0 {
+ left := math.MaxUint64 - binary.BigEndian.Uint64(marker[:8])
+			speed := done/uint64(time.Since(gs.start)/time.Millisecond+1) + 1 // +1 to avoid division by zero
+ eta := time.Duration(left/speed) * time.Millisecond
+
+ ctx = append(ctx, []interface{}{
+ "eta", types.PrettyDuration(eta),
+ }...)
+
+ // collect metric
+ metrics.SetGauge(gs.generateMetrics.estimateSeconds, eta.Seconds())
+ }
+ }
+
+ gs.logger.Info(msg, ctx...)
+ // used second metric
+ metrics.SetGauge(gs.generateMetrics.usedSeconds, time.Since(gs.start).Seconds())
+}
+
+// generatorContext carries a few global values to be shared by all generation functions.
+type generatorContext struct {
+ stats *generatorStats // Generation statistic collection
+ db kvdb.KVBatchStorage // Key-value store containing the snapshot data
+ account *holdableIterator // Iterator of account snapshot data
+ storage *holdableIterator // Iterator of storage snapshot data
+ batch kvdb.Batch // Database batch for writing batch data atomically
+ logged time.Time // The timestamp when last generation progress was displayed
+ generateMetrics *Metrics // The metric
+}
+
+// newGeneratorContext initializes the context for generation.
+func newGeneratorContext(
+ generateMetrics *Metrics,
+ stats *generatorStats,
+ db kvdb.KVBatchStorage,
+ accMarker []byte,
+ storageMarker []byte,
+) *generatorContext {
+ ctx := &generatorContext{
+ stats: stats,
+ db: db,
+ batch: db.NewBatch(),
+ logged: time.Now(),
+ generateMetrics: generateMetrics,
+ }
+
+ ctx.openIterator(snapAccount, accMarker)
+ ctx.openIterator(snapStorage, storageMarker)
+
+ return ctx
+}
+
+// openIterator constructs global account and storage snapshot iterators
+// at the interrupted position. These iterators should be reopened from time
+// to time to avoid blocking leveldb compaction for a long time.
+func (ctx *generatorContext) openIterator(kind string, start []byte) {
+ if kind == snapAccount {
+ iter := ctx.db.NewIterator(rawdb.SnapshotAccountPrefix, start)
+ ctx.account = newHoldableIterator(
+ rawdb.NewKeyLengthIterator(iter, rawdb.SnapshotPrefixLength+types.HashLength),
+ )
+
+ return
+ }
+
+ iter := ctx.db.NewIterator(rawdb.SnapshotStoragePrefix, start)
+ ctx.storage = newHoldableIterator(
+ rawdb.NewKeyLengthIterator(iter, rawdb.SnapshotPrefixLength+2*types.HashLength),
+ )
+}
+
+// reopenIterator releases the specified snapshot iterator and re-open it
+// in the next position. It's aimed for not blocking leveldb compaction.
+func (ctx *generatorContext) reopenIterator(kind string) {
+ // Shift iterator one more step, so that we can reopen
+ // the iterator at the right position.
+ var iter = ctx.account
+ if kind == snapStorage {
+ iter = ctx.storage
+ }
+
+ hasNext := iter.Next()
+
+ if !hasNext {
+ // Iterator exhausted, release forever and create an already exhausted virtual iterator
+ iter.Release()
+
+ if kind == snapAccount {
+ ctx.account = newHoldableIterator(memorydb.New().NewIterator(nil, nil))
+
+ return
+ }
+
+ ctx.storage = newHoldableIterator(memorydb.New().NewIterator(nil, nil))
+
+ return
+ }
+
+ next := iter.Key()
+ iter.Release()
+ ctx.openIterator(kind, next[rawdb.SnapshotPrefixLength:])
+}
+
+// close releases all the held resources.
+func (ctx *generatorContext) close() {
+ ctx.account.Release()
+ ctx.storage.Release()
+}
+
+// iterator returns the corresponding iterator specified by the kind.
+func (ctx *generatorContext) iterator(kind string) *holdableIterator {
+ if kind == snapAccount {
+ return ctx.account
+ }
+
+ return ctx.storage
+}
+
+// removeStorageBefore deletes all storage entries which are located before
+// the specified account. When the iterator touches the storage entry which
+// is located in or outside the given account, it stops and holds the current
+// iterated element locally.
+func (ctx *generatorContext) removeStorageBefore(account types.Hash) {
+ var (
+ count uint64
+ start = time.Now()
+ iter = ctx.storage
+ )
+
+ for iter.Next() {
+ key := iter.Key()
+
+		// the key length is already fixed, so don't worry about the slice going out of bounds
+ if bytes.Compare(
+ key[rawdb.SnapshotPrefixLength:rawdb.SnapshotPrefixLength+types.HashLength],
+ account.Bytes(),
+ ) >= 0 {
+ iter.Hold()
+
+ break
+ }
+
+ count++
+
+ ctx.batch.Delete(key)
+
+ if ctx.batch.ValueSize() > kvdb.IdealBatchSize {
+ ctx.batch.Write()
+ ctx.batch.Reset()
+ }
+ }
+
+ ctx.stats.dangling += count
+
+ metrics.HistogramObserve(ctx.generateMetrics.storageCleanNanoseconds,
+ float64(time.Since(start).Nanoseconds()))
+}
+
+// removeStorageAt deletes all storage entries which are located in the specified
+// account. When the iterator touches the storage entry which is outside the given
+// account, it stops and holds the current iterated element locally. An error will
+// be returned if the initial position of iterator is not in the given account.
+func (ctx *generatorContext) removeStorageAt(account types.Hash) error {
+ var (
+ count int64
+ start = time.Now()
+ iter = ctx.storage
+ )
+
+ for iter.Next() {
+ key := iter.Key()
+		// the key length is already fixed, so don't worry about the slice going out of bounds
+ cmp := bytes.Compare(
+ key[rawdb.SnapshotPrefixLength:rawdb.SnapshotPrefixLength+types.HashLength],
+ account.Bytes(),
+ )
+
+ if cmp < 0 {
+ return errors.New("invalid iterator position")
+ }
+
+ if cmp > 0 {
+ iter.Hold()
+
+ break
+ }
+
+ count++
+
+ ctx.batch.Delete(key)
+
+ if ctx.batch.ValueSize() > kvdb.IdealBatchSize {
+ ctx.batch.Write()
+ ctx.batch.Reset()
+ }
+ }
+
+ // collect metrics
+ metrics.AddCounter(ctx.generateMetrics.wipedStorageCount, float64(count))
+ metrics.HistogramObserve(ctx.generateMetrics.storageCleanNanoseconds,
+ float64(time.Since(start).Nanoseconds()))
+
+ return nil
+}
+
+// removeStorageLeft deletes all storage entries which are located after
+// the current iterator position.
+func (ctx *generatorContext) removeStorageLeft() {
+ var (
+ count uint64
+ start = time.Now()
+ iter = ctx.storage
+ )
+
+ for iter.Next() {
+ count++
+
+ ctx.batch.Delete(iter.Key())
+
+ if ctx.batch.ValueSize() > kvdb.IdealBatchSize {
+ ctx.batch.Write()
+ ctx.batch.Reset()
+ }
+ }
+
+ // collect metrics
+ metrics.AddCounter(ctx.generateMetrics.danglingStorageCount, float64(count))
+ metrics.HistogramObserve(ctx.generateMetrics.storageCleanNanoseconds, float64(time.Since(start).Nanoseconds()))
+
+ ctx.stats.dangling += count
+}
diff --git a/state/snapshot/conversion.go b/state/snapshot/conversion.go
new file mode 100644
index 0000000000..f7611a819b
--- /dev/null
+++ b/state/snapshot/conversion.go
@@ -0,0 +1,402 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "runtime"
+ "sync"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// trieKV represents a trie key-value pair
+type trieKV struct {
+ key types.Hash
+ value []byte
+}
+
+type (
+ // trieGeneratorFn is the interface of trie generation which can
+ // be implemented by different trie algorithm.
+ trieGeneratorFn func(db kvdb.KVWriter, scheme trie.NodeScheme, owner types.Hash,
+ in chan (trieKV), out chan (types.Hash))
+
+ // leafCallbackFn is the callback invoked at the leaves of the trie,
+ // returns the subtrie root with the specified subtrie identifier.
+ leafCallbackFn func(db kvdb.KVWriter, accountHash, codeHash types.Hash, stat *generateStats) (types.Hash, error)
+)
+
+// generateStats is a collection of statistics gathered by the trie generator
+// for logging purposes.
+type generateStats struct {
+ head types.Hash
+ start time.Time
+
+ accounts uint64 // Number of accounts done (including those being crawled)
+ slots uint64 // Number of storage slots done (including those being crawled)
+
+ slotsStart map[types.Hash]time.Time // Start time for account slot crawling
+ slotsHead map[types.Hash]types.Hash // Slot head for accounts being crawled
+
+ lock sync.RWMutex
+ logger kvdb.Logger
+ generateMetrics *Metrics
+}
+
+// newGenerateStats creates a new generator stats.
+func newGenerateStats(logger kvdb.Logger, generateMetrics *Metrics) *generateStats {
+ return &generateStats{
+ slotsStart: make(map[types.Hash]time.Time),
+ slotsHead: make(map[types.Hash]types.Hash),
+ start: time.Now(),
+ logger: logger,
+ generateMetrics: generateMetrics,
+ }
+}
+
+// progressAccounts updates the generator stats for the account range.
+func (stat *generateStats) progressAccounts(account types.Hash, done uint64) {
+ stat.lock.Lock()
+ defer stat.lock.Unlock()
+
+ stat.accounts += done
+ stat.head = account
+}
+
+// finishAccounts updates the generator stats for the finished account range.
+func (stat *generateStats) finishAccounts(done uint64) {
+ stat.lock.Lock()
+ defer stat.lock.Unlock()
+
+ stat.accounts += done
+}
+
+// progressContract updates the generator stats for a specific in-progress contract.
+func (stat *generateStats) progressContract(account types.Hash, slot types.Hash, done uint64) {
+ stat.lock.Lock()
+ defer stat.lock.Unlock()
+
+ stat.slots += done
+ stat.slotsHead[account] = slot
+
+ if _, ok := stat.slotsStart[account]; !ok {
+ stat.slotsStart[account] = time.Now()
+ }
+}
+
+// finishContract updates the generator stats for a specific just-finished contract.
+func (stat *generateStats) finishContract(account types.Hash, done uint64) {
+ stat.lock.Lock()
+ defer stat.lock.Unlock()
+
+ stat.slots += done
+ delete(stat.slotsHead, account)
+ delete(stat.slotsStart, account)
+}
+
+// report prints the cumulative progress statistic smartly.
+func (stat *generateStats) report() {
+ stat.lock.RLock()
+ defer stat.lock.RUnlock()
+
+ ctx := []interface{}{
+ "accounts", stat.accounts,
+ "slots", stat.slots,
+ "elapsed", types.PrettyDuration(time.Since(stat.start)),
+ }
+
+ if stat.accounts > 0 {
+ // If there's progress on the account trie, estimate the time to finish crawling it
+ if done := binary.BigEndian.Uint64(stat.head[:8]) / stat.accounts; done > 0 {
+ var (
+ left = (math.MaxUint64 - binary.BigEndian.Uint64(stat.head[:8])) / stat.accounts
+ speed = done/uint64(time.Since(stat.start)/time.Millisecond+1) + 1 // +1s to avoid division by zero
+ eta = time.Duration(left/speed) * time.Millisecond
+ )
+
+ // If there are large contract crawls in progress, estimate their finish time
+ for acc, head := range stat.slotsHead {
+ start := stat.slotsStart[acc]
+
+ if done := binary.BigEndian.Uint64(head[:8]); done > 0 {
+ var (
+ left = math.MaxUint64 - binary.BigEndian.Uint64(head[:8])
+ speed = done/uint64(time.Since(start)/time.Millisecond+1) + 1 // +1s to avoid division by zero
+ )
+ // Override the ETA if larger than the largest until now
+ if slotETA := time.Duration(left/speed) * time.Millisecond; eta < slotETA {
+ eta = slotETA
+ }
+ }
+ }
+
+ ctx = append(ctx, []interface{}{
+ "eta", types.PrettyDuration(eta),
+ }...)
+
+ // collect metric
+ metrics.SetGauge(stat.generateMetrics.estimateSeconds, eta.Seconds())
+ }
+ }
+
+ // collect metric
+ metrics.SetGauge(stat.generateMetrics.usedSeconds, time.Since(stat.start).Seconds())
+
+ stat.logger.Info("Iterating state snapshot", ctx...)
+}
+
+// reportDone prints the last log when the whole generation is finished.
+func (stat *generateStats) reportDone() {
+ stat.lock.RLock()
+ defer stat.lock.RUnlock()
+
+ var ctx []interface{}
+ ctx = append(ctx, []interface{}{"accounts", stat.accounts}...)
+
+ if stat.slots != 0 {
+ ctx = append(ctx, []interface{}{"slots", stat.slots}...)
+ }
+
+ // collect total used time
+ metrics.SetGauge(stat.generateMetrics.usedSeconds, time.Since(stat.start).Seconds())
+
+ ctx = append(ctx, []interface{}{"elapsed", types.PrettyDuration(time.Since(stat.start))}...)
+ stat.logger.Info("Iterated snapshot", ctx...)
+}
+
+// runReport periodically prints the progress information.
+func runReport(stats *generateStats, stop chan bool) {
+ timer := time.NewTimer(0)
+ defer timer.Stop()
+
+ for {
+ select {
+ case <-timer.C:
+ stats.report()
+ timer.Reset(time.Second * 8)
+ case success := <-stop:
+ if success {
+ stats.reportDone()
+ }
+
+ return
+ }
+ }
+}
+
+// generateTrieRoot generates the trie hash based on the snapshot iterator.
+// It can be used for generating account trie, storage trie or even the
+// whole state which connects the accounts and the corresponding storages.
+func generateTrieRoot(
+ db kvdb.KVWriter,
+ scheme trie.NodeScheme,
+ it Iterator,
+ account types.Hash,
+ generatorFn trieGeneratorFn,
+ leafCallback leafCallbackFn,
+ stats *generateStats,
+ report bool,
+) (types.Hash, error) {
+ var (
+ in = make(chan trieKV) // chan to pass leaves
+ out = make(chan types.Hash, 1) // chan to collect result
+ stoplog = make(chan bool, 1) // 1-size buffer, works when logging is not enabled
+ wg sync.WaitGroup
+ )
+
+ // Spin up a go-routine for trie hash re-generation
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ generatorFn(db, scheme, account, in, out)
+ }()
+
+ // Spin up a go-routine for progress logging
+ if report && stats != nil {
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ runReport(stats, stoplog)
+ }()
+ }
+
+ // Create a semaphore to assign tasks and collect results through. We'll pre-
+ // fill it with nils, thus using the same channel for both limiting concurrent
+ // processing and gathering results.
+ threads := runtime.NumCPU()
+ results := make(chan error, threads)
+
+ for i := 0; i < threads; i++ {
+ results <- nil // fill the semaphore
+ }
+
+ // stop is a helper function to shutdown the background threads
+ // and return the re-generated trie hash.
+ stop := func(fail error) (types.Hash, error) {
+ close(in)
+
+ result := <-out
+
+ for i := 0; i < threads; i++ {
+ if err := <-results; err != nil && fail == nil {
+ fail = err
+ }
+ }
+
+ stoplog <- fail == nil
+
+ wg.Wait()
+
+ return result, fail
+ }
+
+ var (
+ logged = time.Now()
+ processed = uint64(0)
+ leaf trieKV
+ )
+
+ // Start to feed leaves
+ for it.Next() {
+ if account == (types.Hash{}) {
+ var (
+ err error
+ fullData []byte
+ )
+
+ if leafCallback == nil {
+ fullData, err = FullAccountRLP(it.(AccountIterator).Account())
+ if err != nil {
+ return stop(err)
+ }
+ } else {
+ // Wait until the semaphore allows us to continue, aborting if
+ // a sub-task failed
+ if err := <-results; err != nil {
+ results <- nil // stop will drain the results, add a noop back for this error we just consumed
+
+ return stop(err)
+ }
+
+ // Fetch the next account and process it concurrently
+ account, err := FullAccount(it.(AccountIterator).Account())
+ if err != nil {
+ return stop(err)
+ }
+
+ go func(hash types.Hash) {
+ subroot, err := leafCallback(db, hash, types.BytesToHash(account.CodeHash), stats)
+ if err != nil {
+ results <- err
+
+ return
+ }
+
+ if !bytes.Equal(account.Root, subroot.Bytes()) {
+ results <- fmt.Errorf("invalid subroot(path %s), want %s, have %s",
+ hash, account.Root, subroot)
+
+ return
+ }
+
+ results <- nil
+ }(it.Hash())
+
+ fullData, err = rlp.EncodeToBytes(account)
+ if err != nil {
+ return stop(err)
+ }
+ }
+
+ leaf = trieKV{it.Hash(), fullData}
+ } else {
+ //nolint:forcetypeassert
+ leaf = trieKV{it.Hash(), types.CopyBytes(it.(StorageIterator).Slot())}
+ }
+
+ in <- leaf
+
+ // Accumulate the generation statistic if it's required.
+ processed++
+
+ if time.Since(logged) > 3*time.Second && stats != nil {
+ if account == (types.Hash{}) {
+ stats.progressAccounts(it.Hash(), processed)
+ } else {
+ stats.progressContract(account, it.Hash(), processed)
+ }
+
+ logged, processed = time.Now(), 0
+ }
+ }
+
+ // Commit the last part statistic.
+ if processed > 0 && stats != nil {
+ if account == (types.Hash{}) {
+ stats.finishAccounts(processed)
+ } else {
+ stats.finishContract(account, processed)
+ }
+ }
+
+ return stop(nil)
+}
+
+func stackTrieGenerate(
+ db kvdb.KVWriter,
+ scheme trie.NodeScheme,
+ owner types.Hash,
+ in chan trieKV,
+ out chan types.Hash,
+) {
+ var nodeWriter trie.NodeWriteFunc
+
+ if db != nil {
+ nodeWriter = func(owner types.Hash, path []byte, hash types.Hash, blob []byte) {
+ scheme.WriteTrieNode(db, owner, path, hash, blob)
+ }
+ }
+
+ t := trie.NewStackTrieWithOwner(nodeWriter, owner)
+
+ for leaf := range in {
+ t.TryUpdate(leaf.key[:], leaf.value)
+ }
+
+ var root types.Hash
+
+ if db == nil {
+ root = t.Hash()
+ } else {
+ root, _ = t.Commit()
+ }
+
+ out <- root
+}
diff --git a/state/snapshot/difflayer.go b/state/snapshot/difflayer.go
new file mode 100644
index 0000000000..111637061a
--- /dev/null
+++ b/state/snapshot/difflayer.go
@@ -0,0 +1,645 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/big"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+ bloomfilter "github.com/holiman/bloomfilter/v2"
+)
+
+var (
+ // aggregatorMemoryLimit is the maximum size of the bottom-most diff layer
+ // that aggregates the writes from above until it's flushed into the disk
+ // layer.
+ //
+ // Note, bumping this up might drastically increase the size of the bloom
+ // filters that's stored in every diff layer. Don't do that without fully
+ // understanding all the implications.
+ aggregatorMemoryLimit = uint64(4 * 1024 * 1024)
+
+ // aggregatorItemLimit is an approximate number of items that will end up
+ // in the aggregator layer before it's flushed out to disk. A plain account
+ // weighs around 14B (+hash), a storage slot 32B (+hash), a deleted slot
+ // 0B (+hash). Slots are mostly set/unset in lockstep, so that average at
+ // 16B (+hash). All in all, the average entry seems to be 15+32=47B. Use a
+ // smaller number to be on the safe side.
+ aggregatorItemLimit = aggregatorMemoryLimit / 42
+
+ // bloomTargetError is the target false positive rate when the aggregator
+ // layer is at its fullest. The actual value will probably move around up
+ // and down from this number, it's mostly a ballpark figure.
+ //
+ // Note, dropping this down might drastically increase the size of the bloom
+ // filters that's stored in every diff layer. Don't do that without fully
+ // understanding all the implications.
+ bloomTargetError = 0.02
+
+ // bloomSize is the ideal bloom filter size given the maximum number of items
+ // it's expected to hold and the target false positive error rate.
+ bloomSize = math.Ceil(float64(aggregatorItemLimit) * math.Log(bloomTargetError) / math.Log(1/math.Pow(2, math.Log(2))))
+
+ // bloomFuncs is the ideal number of bits a single entry should set in the
+ // bloom filter to keep its size to a minimum (given its size and maximum
+ // entry count).
+ bloomFuncs = math.Round((bloomSize / float64(aggregatorItemLimit)) * math.Log(2))
+
+ // the bloom offsets are runtime constants which determine which part of the
+ // account/storage hash the hasher functions looks at, to determine the
+ // bloom key for an account/slot. This is randomized at init(), so that the
+ // global population of nodes do not all display the exact same behaviour with
+ // regards to bloom content
+ bloomDestructHasherOffset = 0
+ bloomAccountHasherOffset = 0
+ bloomStorageHasherOffset = 0
+)
+
+// diffLayer represents a collection of modifications made to a state snapshot
+// after running a block on top. It contains one sorted list for the account trie
+// and one list per account for each of the storage tries.
+//
+// The goal of a diff layer is to act as a journal, tracking recent modifications
+// made to the state, that have not yet graduated into a semi-immutable state.
+type diffLayer struct {
+ origin *diskLayer // Base disk layer to directly use on bloom misses
+ parent snapshot // Parent snapshot modified by this one, never nil
+ memory uint64 // Approximate guess as to how much memory we use
+
+ root types.Hash // Root hash to which this snapshot diff belongs to
+ stale uint32 // Signals that the layer became stale (state progressed)
+
+ // destructSet is a very special helper marker. If an account is marked as
+ // deleted, then it's recorded in this set. However it's allowed that an account
+ // is included here but still available in other sets(e.g. storageData). The
+ // reason is the diff layer includes all the changes in a *block*. It can
+ // happen that in the tx_1, account A is self-destructed while in the tx_2
+ // it's recreated. But we still need this marker to indicate the "old" A is
+ // deleted, all data in other set belongs to the "new" A.
+ // Keyed markers for deleted (and potentially) recreated accounts
+ destructSet map[types.Hash]struct{}
+ // List of accounts for iteration. If it exists, it's sorted, otherwise it's nil
+ accountList []types.Hash
+ // Keyed accounts for direct retrieval (nil means deleted)
+ accountData map[types.Hash][]byte
+ // List of storage slots for iterated retrievals, one per account. Any existing lists are sorted if non-nil
+ storageList map[types.Hash][]types.Hash
+ // Keyed storage slots for direct retrieval. one per account (nil means deleted)
+ storageData map[types.Hash]map[types.Hash][]byte
+
+ // Bloom filter tracking all the diffed items up to the disk layer
+ diffed *bloomfilter.Filter
+
+ lock sync.RWMutex
+
+ // shared logger for print out debug info
+ logger kvdb.Logger
+ // shared snapshot metrics
+ snapmetrics *Metrics
+}
+
+// newDiffLayer creates a new diff on top of an existing snapshot, whether that's a low
+// level persistent database or a hierarchical diff already.
+func newDiffLayer(
+ parent snapshot,
+ root types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ logger kvdb.Logger,
+ snapmetrics *Metrics,
+) *diffLayer {
+ // Create the new layer with some pre-allocated data segments
+ dl := &diffLayer{
+ parent: parent,
+ root: root,
+ destructSet: destructs,
+ accountData: accounts,
+ storageData: storage,
+ storageList: make(map[types.Hash][]types.Hash),
+ logger: logger,
+ snapmetrics: snapmetrics,
+ }
+
+ switch parent := parent.(type) {
+ case *diskLayer:
+ dl.rebloom(parent)
+ case *diffLayer:
+ dl.rebloom(parent.origin)
+ default:
+ panic("unknown parent type")
+ }
+
+ // Sanity check that accounts or storage slots are never nil
+ for accountHash, blob := range accounts {
+ if blob == nil {
+ panic(fmt.Sprintf("account %s nil", accountHash))
+ }
+
+ // Determine memory size and track the dirty writes
+ dl.memory += uint64(types.HashLength + len(blob))
+ metrics.HistogramObserve(dl.snapmetrics.dirtyAccountWriteSize, float64(len(blob)))
+ }
+
+ for accountHash, slots := range storage {
+ if slots == nil {
+ panic(fmt.Sprintf("storage %s nil", accountHash))
+ }
+ // Determine memory size and track the dirty writes
+ for _, data := range slots {
+ dl.memory += uint64(types.HashLength + len(data))
+ metrics.HistogramObserve(dl.snapmetrics.dirtyStorageWriteSize, float64(len(data)))
+ }
+ }
+
+ dl.memory += uint64(len(destructs) * types.HashLength)
+
+ return dl
+}
+
+// Root returns the root hash for which this snapshot was made.
+func (dl *diffLayer) Root() types.Hash {
+ return dl.root
+}
+
+// Parent returns the subsequent layer of a diff layer.
+func (dl *diffLayer) Parent() snapshot {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ return dl.parent
+}
+
+// Stale return whether this layer has become stale (was flattened across) or if
+// it's still live.
+func (dl *diffLayer) Stale() bool {
+ return atomic.LoadUint32(&dl.stale) != 0
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+func (dl *diffLayer) Account(hash types.Hash) (*stypes.Account, error) {
+ data, err := dl.AccountRLP(hash)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(data) == 0 { // can be both nil and []byte{}
+ return nil, nil
+ }
+
+ // slim account
+ var slimAccount Account
+ if err := rlp.DecodeBytes(data, &slimAccount); err != nil {
+ panic(err)
+ }
+
+ // copy balance to heap
+ balance := new(big.Int).Set(slimAccount.Balance)
+
+ return &stypes.Account{
+ Nonce: slimAccount.Nonce,
+ Balance: balance,
+ StorageRoot: types.BytesToHash(slimAccount.Root),
+ CodeHash: slimAccount.CodeHash,
+ }, nil
+}
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+//
+// Note the returned account is not a copy, please don't modify it.
+func (dl *diffLayer) AccountRLP(hash types.Hash) ([]byte, error) {
+ // Check the bloom filter first whether there's even a point in reaching into
+ // all the maps in all the layers below
+ dl.lock.RLock()
+
+ hit := dl.diffed.Contains(accountBloomHasher(hash))
+ if !hit {
+ hit = dl.diffed.Contains(destructBloomHasher(hash))
+ }
+
+ var origin *diskLayer
+
+ if !hit {
+ origin = dl.origin // extract origin while holding the lock
+ }
+
+ dl.lock.RUnlock()
+
+ // If the bloom filter misses, don't even bother with traversing the memory
+ // diff layers, reach straight into the bottom persistent disk layer
+ if origin != nil {
+ metrics.CounterInc(dl.snapmetrics.bloomAccountMissCount)
+
+ return origin.AccountRLP(hash)
+ }
+
+ // The bloom filter hit, start poking in the internal maps
+ return dl.accountRLP(hash, 0)
+}
+
+// accountRLP is an internal version of AccountRLP that skips the bloom filter
+// checks and uses the internal maps to try and retrieve the data. It's meant
+// to be used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) accountRLP(hash types.Hash, depth int) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.Stale() {
+ return nil, ErrSnapshotStale
+ }
+ // If the account is known locally, return it
+ if data, ok := dl.accountData[hash]; ok {
+ metrics.CounterInc(dl.snapmetrics.dirtyAccountHitCount)
+ metrics.HistogramObserve(dl.snapmetrics.dirtyAccountHitDepth, float64(depth))
+ metrics.HistogramObserve(dl.snapmetrics.dirtyAccountReadSize, float64(len(data)))
+ metrics.CounterInc(dl.snapmetrics.bloomAccountTrueHitCount)
+
+ return data, nil
+ }
+ // If the account is known locally, but deleted, return it
+ if _, ok := dl.destructSet[hash]; ok {
+ metrics.CounterInc(dl.snapmetrics.dirtyAccountHitCount)
+ metrics.HistogramObserve(dl.snapmetrics.dirtyAccountHitDepth, float64(depth))
+ metrics.CounterInc(dl.snapmetrics.dirtyAccountInexCount)
+ metrics.CounterInc(dl.snapmetrics.bloomAccountTrueHitCount)
+
+ return nil, nil
+ }
+ // Account unknown to this diff, resolve from parent
+ if diff, ok := dl.parent.(*diffLayer); ok {
+ return diff.accountRLP(hash, depth+1)
+ }
+
+ // Failed to resolve through diff layers, mark a bloom error and use the disk
+ metrics.CounterInc(dl.snapmetrics.bloomAccountFalseHitCount)
+
+ return dl.parent.AccountRLP(hash)
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account. If the slot is unknown to this diff, it's parent
+// is consulted.
+//
+// Note the returned slot is not a copy, please don't modify it.
+func (dl *diffLayer) Storage(accountHash, storageHash types.Hash) ([]byte, error) {
+ // Check the bloom filter first whether there's even a point in reaching into
+ // all the maps in all the layers below
+ dl.lock.RLock()
+
+ hit := dl.diffed.Contains(storageBloomHasher{accountHash, storageHash})
+ if !hit {
+ hit = dl.diffed.Contains(destructBloomHasher(accountHash))
+ }
+
+ var origin *diskLayer
+ if !hit {
+ origin = dl.origin // extract origin while holding the lock
+ }
+
+ dl.lock.RUnlock()
+
+ // If the bloom filter misses, don't even bother with traversing the memory
+ // diff layers, reach straight into the bottom persistent disk layer
+ if origin != nil {
+ metrics.CounterInc(dl.snapmetrics.bloomStorageMissCount)
+
+ return origin.Storage(accountHash, storageHash)
+ }
+ // The bloom filter hit, start poking in the internal maps
+ return dl.storage(accountHash, storageHash, 0)
+}
+
+// storage is an internal version of Storage that skips the bloom filter checks
+// and uses the internal maps to try and retrieve the data. It's meant to be
+// used if a higher layer's bloom filter hit already.
+func (dl *diffLayer) storage(accountHash, storageHash types.Hash, depth int) ([]byte, error) {
+ dl.lock.RLock()
+ defer dl.lock.RUnlock()
+
+ // If the layer was flattened into, consider it invalid (any live reference to
+ // the original should be marked as unusable).
+ if dl.Stale() {
+ return nil, ErrSnapshotStale
+ }
+ // If the account is known locally, try to resolve the slot locally
+ if storage, ok := dl.storageData[accountHash]; ok {
+ if data, ok := storage[storageHash]; ok {
+ metrics.CounterInc(dl.snapmetrics.dirtyStorageHitCount)
+ metrics.HistogramObserve(dl.snapmetrics.dirtyStorageHitDepth, float64(depth))
+ metrics.CounterInc(dl.snapmetrics.bloomStorageTrueHitCount)
+
+ if n := len(data); n > 0 {
+ metrics.HistogramObserve(dl.snapmetrics.dirtyStorageReadSize, float64(n))
+ } else {
+ metrics.CounterInc(dl.snapmetrics.dirtyStorageInexCount)
+ }
+
+ return data, nil
+ }
+ }
+ // If the account is known locally, but deleted, return an empty slot
+ if _, ok := dl.destructSet[accountHash]; ok {
+ metrics.CounterInc(dl.snapmetrics.dirtyStorageHitCount)
+ metrics.HistogramObserve(dl.snapmetrics.dirtyStorageHitDepth, float64(depth))
+ metrics.CounterInc(dl.snapmetrics.dirtyStorageInexCount)
+ metrics.CounterInc(dl.snapmetrics.bloomStorageTrueHitCount)
+
+ return nil, nil
+ }
+ // Storage slot unknown to this diff, resolve from parent
+ if diff, ok := dl.parent.(*diffLayer); ok {
+ return diff.storage(accountHash, storageHash, depth+1)
+ }
+
+ // Failed to resolve through diff layers, mark a bloom error and use the disk
+ metrics.CounterInc(dl.snapmetrics.bloomStorageFalseHitCount)
+
+ return dl.parent.Storage(accountHash, storageHash)
+}
+
+// Update creates a new layer on top of the existing snapshot diff tree with
+// the specified data items.
+func (dl *diffLayer) Update(
+ blockRoot types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ logger kvdb.Logger,
+) *diffLayer {
+ return newDiffLayer(dl, blockRoot, destructs, accounts, storage, logger, dl.snapmetrics)
+}
+
+// AccountList returns a sorted list of all accounts in this diffLayer, including
+// the deleted ones.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) AccountList() []types.Hash {
+ // If an old list already exists, return it
+ dl.lock.RLock()
+
+ list := dl.accountList
+
+ dl.lock.RUnlock()
+
+ if list != nil {
+ return list
+ }
+
+ // No old sorted account list exists, generate a new one
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ dl.accountList = make([]types.Hash, 0, len(dl.destructSet)+len(dl.accountData))
+
+ for hash := range dl.accountData {
+ dl.accountList = append(dl.accountList, hash)
+ }
+
+ for hash := range dl.destructSet {
+ if _, ok := dl.accountData[hash]; !ok {
+ dl.accountList = append(dl.accountList, hash)
+ }
+ }
+
+ sort.Sort(hashes(dl.accountList))
+
+ dl.memory += uint64(len(dl.accountList) * types.HashLength)
+
+ return dl.accountList
+}
+
+// StorageList returns a sorted list of all storage slot hashes in this diffLayer
+// for the given account. If the whole storage is destructed in this layer, then
+// an additional flag *destructed = true* will be returned, otherwise the flag is
+// false. Besides, the returned list will include the hash of deleted storage slot.
+// Note a special case is an account is deleted in a prior tx but is recreated in
+// the following tx with some storage slots set. In this case the returned list is
+// not empty but the flag is true.
+//
+// Note, the returned slice is not a copy, so do not modify it.
+func (dl *diffLayer) StorageList(accountHash types.Hash) ([]types.Hash, bool) {
+ dl.lock.RLock()
+
+ _, destructed := dl.destructSet[accountHash]
+
+ if _, ok := dl.storageData[accountHash]; !ok {
+ // Account not tracked by this layer
+ dl.lock.RUnlock()
+
+ return nil, destructed
+ }
+
+ // If an old list already exists, return it
+ if list, exist := dl.storageList[accountHash]; exist {
+ dl.lock.RUnlock()
+
+ return list, destructed // the cached list can't be nil
+ }
+
+ dl.lock.RUnlock()
+
+ // No old sorted account list exists, generate a new one
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ storageMap := dl.storageData[accountHash]
+ storageList := make([]types.Hash, 0, len(storageMap))
+
+ for k := range storageMap {
+ storageList = append(storageList, k)
+ }
+
+ sort.Sort(hashes(storageList))
+
+ dl.storageList[accountHash] = storageList
+ dl.memory += uint64(len(dl.storageList)*types.HashLength + types.HashLength)
+
+ return storageList, destructed
+}
+
+// flatten pushes all data from this point downwards, flattening everything into
+// a single diff at the bottom. Since usually the lowermost diff is the largest,
+// the flattening builds up from there in reverse.
+func (dl *diffLayer) flatten() snapshot {
+ // If the parent is not diff, we're the first in line, return unmodified
+ parent, ok := dl.parent.(*diffLayer)
+ if !ok {
+ return dl
+ }
+
+ // Parent is a diff, flatten it first (note, apart from weird corner cases,
+ // flatten will realistically only ever merge 1 layer, so there's no need to
+ // be smarter about grouping flattens together).
+ parent, _ = parent.flatten().(*diffLayer)
+
+ parent.lock.Lock()
+ defer parent.lock.Unlock()
+
+ // Before actually writing all our data to the parent, first ensure that the
+ // parent hasn't been 'corrupted' by someone else already flattening into it
+ if atomic.SwapUint32(&parent.stale, 1) != 0 {
+ // we've flattened into the same parent from two children, boo
+ panic("parent diff layer is stale")
+ }
+
+ // Overwrite all the updated accounts blindly, merge the sorted list
+ for hash := range dl.destructSet {
+ parent.destructSet[hash] = struct{}{}
+ delete(parent.accountData, hash)
+ delete(parent.storageData, hash)
+ }
+
+ for hash, data := range dl.accountData {
+ parent.accountData[hash] = data
+ }
+
+ // Overwrite all the updated storage slots (individually)
+ for accountHash, storage := range dl.storageData {
+ // If storage didn't exist (or was deleted) in the parent, overwrite blindly
+ if _, ok := parent.storageData[accountHash]; !ok {
+ parent.storageData[accountHash] = storage
+
+ continue
+ }
+ // Storage exists in both parent and child, merge the slots
+ comboData := parent.storageData[accountHash]
+
+ for storageHash, data := range storage {
+ comboData[storageHash] = data
+ }
+ }
+
+ // Return the combo parent
+ return &diffLayer{
+ parent: parent.parent,
+ origin: parent.origin,
+ root: dl.root,
+ destructSet: parent.destructSet,
+ accountData: parent.accountData,
+ storageData: parent.storageData,
+ storageList: make(map[types.Hash][]types.Hash),
+ diffed: dl.diffed,
+ memory: parent.memory + dl.memory,
+ logger: dl.logger,
+ snapmetrics: dl.snapmetrics,
+ }
+}
+
+// rebloom discards the layer's current bloom and rebuilds it from scratch based
+// on the parent's and the local diffs.
+func (dl *diffLayer) rebloom(origin *diskLayer) {
+ dl.lock.Lock()
+ defer dl.lock.Unlock()
+
+ defer func(start time.Time) {
+ metrics.HistogramObserve(dl.snapmetrics.bloomIndexNanoseconds, float64(time.Since(start).Nanoseconds()))
+ }(time.Now())
+
+ // Inject the new origin that triggered the rebloom
+ dl.origin = origin
+
+ // Retrieve the parent bloom or create a fresh empty one
+ if parent, ok := dl.parent.(*diffLayer); ok {
+ parent.lock.RLock()
+ dl.diffed, _ = parent.diffed.Copy()
+ parent.lock.RUnlock()
+ } else {
+ dl.diffed, _ = bloomfilter.New(uint64(bloomSize), uint64(bloomFuncs))
+ }
+
+ // Iterate over all the accounts and storage slots and index them
+ for hash := range dl.destructSet {
+ dl.diffed.Add(destructBloomHasher(hash))
+ }
+
+ for hash := range dl.accountData {
+ dl.diffed.Add(accountBloomHasher(hash))
+ }
+
+ for accountHash, slots := range dl.storageData {
+ for storageHash := range slots {
+ dl.diffed.Add(storageBloomHasher{accountHash, storageHash})
+ }
+ }
+
+ // Calculate the current false positive rate and update the error rate meter.
+ // This is a bit cheating because subsequent layers will overwrite it, but it
+ // should be fine, we're only interested in ballpark figures.
+ k := float64(dl.diffed.K())
+ n := float64(dl.diffed.N())
+ m := float64(dl.diffed.M())
+
+ metrics.SetGauge(dl.snapmetrics.bloomErrorCount, math.Pow(1.0-math.Exp((-k)*(n+0.5)/(m-1)), k))
+}
+
+// destructBloomHasher is a wrapper around a types.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert a destruct
+// event into a 64 bit mini hash.
+type destructBloomHasher types.Hash
+
+func (h destructBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h destructBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h destructBloomHasher) Reset() { panic("not implemented") }
+func (h destructBloomHasher) BlockSize() int { panic("not implemented") }
+func (h destructBloomHasher) Size() int { return 8 }
+func (h destructBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[bloomDestructHasherOffset : bloomDestructHasherOffset+8])
+}
+
+// accountBloomHasher is a wrapper around a types.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// hash into a 64 bit mini hash.
+type accountBloomHasher types.Hash
+
+func (h accountBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h accountBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h accountBloomHasher) Reset() { panic("not implemented") }
+func (h accountBloomHasher) BlockSize() int { panic("not implemented") }
+func (h accountBloomHasher) Size() int { return 8 }
+func (h accountBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[bloomAccountHasherOffset : bloomAccountHasherOffset+8])
+}
+
+// storageBloomHasher is a wrapper around a [2]types.Hash to satisfy the interface
+// API requirements of the bloom library used. It's used to convert an account
+// hash into a 64 bit mini hash.
+type storageBloomHasher [2]types.Hash
+
+func (h storageBloomHasher) Write(p []byte) (n int, err error) { panic("not implemented") }
+func (h storageBloomHasher) Sum(b []byte) []byte { panic("not implemented") }
+func (h storageBloomHasher) Reset() { panic("not implemented") }
+func (h storageBloomHasher) BlockSize() int { panic("not implemented") }
+func (h storageBloomHasher) Size() int { return 8 }
+func (h storageBloomHasher) Sum64() uint64 {
+ return binary.BigEndian.Uint64(h[0][bloomStorageHasherOffset:bloomStorageHasherOffset+8]) ^
+ binary.BigEndian.Uint64(h[1][bloomStorageHasherOffset:bloomStorageHasherOffset+8])
+}
diff --git a/state/snapshot/difflayer_test.go b/state/snapshot/difflayer_test.go
new file mode 100644
index 0000000000..43d5824484
--- /dev/null
+++ b/state/snapshot/difflayer_test.go
@@ -0,0 +1,459 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/keccak"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// copyDestructs returns an independent copy of a destruct set, so a test can
+// hand the same logical data to several diff layers without map aliasing.
+func copyDestructs(destructs map[types.Hash]struct{}) map[types.Hash]struct{} {
+	cp := make(map[types.Hash]struct{})
+	for hash := range destructs {
+		cp[hash] = struct{}{}
+	}
+
+	return cp
+}
+
+// copyAccounts returns a copy of an account map. Note: only the map itself is
+// duplicated — the account blobs are shared with the original, which is fine
+// for these tests since the blobs are never mutated.
+func copyAccounts(accounts map[types.Hash][]byte) map[types.Hash][]byte {
+	cp := make(map[types.Hash][]byte)
+	for hash, blob := range accounts {
+		cp[hash] = blob
+	}
+
+	return cp
+}
+
+// copyStorage returns a copy of a two-level storage map (account -> slot ->
+// value). Both map levels are duplicated; the value blobs themselves are
+// shared with the original (never mutated by these tests).
+func copyStorage(storage map[types.Hash]map[types.Hash][]byte) map[types.Hash]map[types.Hash][]byte {
+	cp := make(map[types.Hash]map[types.Hash][]byte)
+	for accHash, slots := range storage {
+		cp[accHash] = make(map[types.Hash][]byte)
+		for slotHash, blob := range slots {
+			cp[accHash][slotHash] = blob
+		}
+	}
+
+	return cp
+}
+
+// TestMergeBasics tests some simple merges
+func TestMergeBasics(t *testing.T) {
+	var (
+		destructs = make(map[types.Hash]struct{})
+		accounts  = make(map[types.Hash][]byte)
+		storage   = make(map[types.Hash]map[types.Hash][]byte)
+	)
+
+	// Fill up a parent with 100 random accounts; roughly 1/4 are also marked
+	// destructed and roughly 1/2 receive a single random storage slot.
+	for i := 0; i < 100; i++ {
+		h := randomHash()
+		data := randomAccount()
+
+		accounts[h] = data
+
+		if rand.Intn(4) == 0 {
+			destructs[h] = struct{}{}
+		}
+
+		if rand.Intn(2) == 0 {
+			accStorage := make(map[types.Hash][]byte)
+			value := make([]byte, 32)
+			rand.Read(value)
+			accStorage[randomHash()] = value
+			storage[h] = accStorage
+		}
+	}
+
+	logger := hclog.NewNullLogger()
+	snapmetrics := NilMetrics()
+
+	// Add some (identical) layers on top
+	parent := newDiffLayer(emptyLayer(), types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), logger, snapmetrics)
+	child := newDiffLayer(parent, types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), logger, snapmetrics)
+	child = newDiffLayer(child, types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), logger, snapmetrics)
+	child = newDiffLayer(child, types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), logger, snapmetrics)
+	child = newDiffLayer(child, types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), logger, snapmetrics)
+	// And flatten
+	merged, _ := (child.flatten()).(*diffLayer)
+
+	{ // Check account lists
+		// accountList is built lazily: it must be empty before the first
+		// AccountList() call and fully populated afterwards.
+		if have, want := len(merged.accountList), 0; have != want {
+			t.Errorf("accountList wrong: have %v, want %v", have, want)
+		}
+		if have, want := len(merged.AccountList()), len(accounts); have != want {
+			t.Errorf("AccountList() wrong: have %v, want %v", have, want)
+		}
+		if have, want := len(merged.accountList), len(accounts); have != want {
+			t.Errorf("accountList [2] wrong: have %v, want %v", have, want)
+		}
+	}
+	{ // Check account drops
+		if have, want := len(merged.destructSet), len(destructs); have != want {
+			t.Errorf("accountDrop wrong: have %v, want %v", have, want)
+		}
+	}
+	{ // Check storage lists
+		// storageList is likewise lazy: an entry appears only once
+		// StorageList has been called for that account, hence the counter.
+		i := 0
+		for aHash, sMap := range storage {
+			if have, want := len(merged.storageList), i; have != want {
+				t.Errorf("[1] storageList wrong: have %v, want %v", have, want)
+			}
+			list, _ := merged.StorageList(aHash)
+			if have, want := len(list), len(sMap); have != want {
+				t.Errorf("[2] StorageList() wrong: have %v, want %v", have, want)
+			}
+			if have, want := len(merged.storageList[aHash]), len(sMap); have != want {
+				t.Errorf("storageList wrong: have %v, want %v", have, want)
+			}
+			i++
+		}
+	}
+}
+
+// TestMergeDelete tests some deletion
+func TestMergeDelete(t *testing.T) {
+	var (
+		storage = make(map[types.Hash]map[types.Hash][]byte)
+	)
+
+	// Fill up a parent
+	h1 := types.StringToHash("0x01")
+	h2 := types.StringToHash("0x02")
+
+	// "flip" layers create h1 and destruct h2; "flop" layers do the opposite.
+	flipDrops := func() map[types.Hash]struct{} {
+		return map[types.Hash]struct{}{
+			h2: {},
+		}
+	}
+	flipAccs := func() map[types.Hash][]byte {
+		return map[types.Hash][]byte{
+			h1: randomAccount(),
+		}
+	}
+	flopDrops := func() map[types.Hash]struct{} {
+		return map[types.Hash]struct{}{
+			h1: {},
+		}
+	}
+	flopAccs := func() map[types.Hash][]byte {
+		return map[types.Hash][]byte{
+			h2: randomAccount(),
+		}
+	}
+	logger := hclog.NewNullLogger()
+	snapmetrics := NilMetrics()
+
+	// Add some flipAccs-flopping layers on top; the stack ends on a "flip",
+	// so the topmost layer must show h1 alive and h2 destructed.
+	parent := newDiffLayer(emptyLayer(), types.Hash{}, flipDrops(), flipAccs(), storage, logger, snapmetrics)
+	child := parent.Update(types.Hash{}, flopDrops(), flopAccs(), storage, logger)
+	child = child.Update(types.Hash{}, flipDrops(), flipAccs(), storage, logger)
+	child = child.Update(types.Hash{}, flopDrops(), flopAccs(), storage, logger)
+	child = child.Update(types.Hash{}, flipDrops(), flipAccs(), storage, logger)
+	child = child.Update(types.Hash{}, flopDrops(), flopAccs(), storage, logger)
+	child = child.Update(types.Hash{}, flipDrops(), flipAccs(), storage, logger)
+
+	if data, _ := child.Account(h1); data == nil {
+		t.Errorf("last diff layer: expected %x account to be non-nil", h1)
+	}
+
+	if data, _ := child.Account(h2); data != nil {
+		t.Errorf("last diff layer: expected %x account to be nil", h2)
+	}
+
+	if _, ok := child.destructSet[h1]; ok {
+		t.Errorf("last diff layer: expected %x drop to be missing", h1)
+	}
+
+	if _, ok := child.destructSet[h2]; !ok {
+		t.Errorf("last diff layer: expected %x drop to be present", h2)
+	}
+
+	// And flatten
+	merged, _ := (child.flatten()).(*diffLayer)
+
+	// If we add more granular metering of memory, we can enable this again,
+	// but it's not implemented for now
+	//if have, want := merged.memory, child.memory; have != want {
+	//	t.Errorf("mem wrong: have %d, want %d", have, want)
+	//}
+
+	if data, _ := merged.Account(h1); data == nil {
+		t.Errorf("merged layer: expected %x account to be non-nil", h1)
+	}
+
+	if data, _ := merged.Account(h2); data != nil {
+		t.Errorf("merged layer: expected %x account to be nil", h2)
+	}
+
+	if _, ok := merged.destructSet[h1]; !ok { // Note, drops stay alive until persisted to disk!
+		t.Errorf("merged diff layer: expected %x drop to be present", h1)
+	}
+
+	if _, ok := merged.destructSet[h2]; !ok { // Note, drops stay alive until persisted to disk!
+		t.Errorf("merged diff layer: expected %x drop to be present", h2)
+	}
+}
+
+// This tests that if we create a new account, and set a slot, and then merge
+// it, the lists will be correct.
+func TestInsertAndMerge(t *testing.T) {
+	// Fill up a parent
+	var (
+		acc         = types.StringToHash("0x01")
+		slot        = types.StringToHash("0x02")
+		parent      *diffLayer
+		child       *diffLayer
+		logger      = hclog.NewNullLogger()
+		snapmetrics = NilMetrics()
+	)
+
+	{
+		// Parent layer: completely empty.
+		var (
+			destructs = make(map[types.Hash]struct{})
+			accounts  = make(map[types.Hash][]byte)
+			storage   = make(map[types.Hash]map[types.Hash][]byte)
+		)
+		parent = newDiffLayer(emptyLayer(), types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+	}
+	{
+		// Child layer: introduces the account together with one storage slot.
+		var (
+			destructs = make(map[types.Hash]struct{})
+			accounts  = make(map[types.Hash][]byte)
+			storage   = make(map[types.Hash]map[types.Hash][]byte)
+		)
+		accounts[acc] = randomAccount()
+		storage[acc] = make(map[types.Hash][]byte)
+		storage[acc][slot] = []byte{0x01}
+		child = newDiffLayer(parent, types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+	}
+
+	// And flatten
+	merged, _ := (child.flatten()).(*diffLayer)
+
+	{
+		// Check that slot value is present
+		have, _ := merged.Storage(acc, slot)
+		if want := []byte{0x01}; !bytes.Equal(have, want) {
+			t.Errorf("merged slot value wrong: have %x, want %x", have, want)
+		}
+	}
+}
+
+// emptyLayer returns a fresh, empty disk layer backed by an in-memory
+// database, used as the base of the diff-layer stacks built in these tests.
+func emptyLayer() *diskLayer {
+	return &diskLayer{
+		diskdb:      memorydb.New(),
+		cache:       fastcache.New(500 * 1024),
+		logger:      hclog.NewNullLogger(),
+		snapmetrics: NilMetrics(),
+	}
+}
+
+// BenchmarkSearch checks how long it takes to find a non-existing key
+// BenchmarkSearch-6 200000 10481 ns/op (1K per layer)
+// BenchmarkSearch-6 200000 10760 ns/op (10K per layer)
+// BenchmarkSearch-6 100000 17866 ns/op
+//
+// BenchmarkSearch-6 500000 3723 ns/op (10k per layer, only top-level RLock()
+func BenchmarkSearch(b *testing.B) {
+ // First, we set up 128 diff layers, with 1K items each
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[types.Hash]struct{})
+ accounts = make(map[types.Hash][]byte)
+ storage = make(map[types.Hash]map[types.Hash][]byte)
+ logger = hclog.NewNullLogger()
+ snapmetrics = NilMetrics()
+ )
+
+ for i := 0; i < 10000; i++ {
+ accounts[randomHash()] = randomAccount()
+ }
+
+ return newDiffLayer(parent, types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+ }
+
+ var layer snapshot = emptyLayer()
+ for i := 0; i < 128; i++ {
+ layer = fill(layer)
+ }
+
+ key := keccak.Keccak256(nil, []byte{0x13, 0x38})
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ layer.AccountRLP(types.BytesToHash(key))
+ }
+}
+
+// BenchmarkSearchSlot checks how long it takes to find a non-existing key
+// - Number of layers: 128
+// - Each layers contains the account, with a couple of storage slots
+// BenchmarkSearchSlot-6 100000 14554 ns/op
+// BenchmarkSearchSlot-6 100000 22254 ns/op (when checking parent root using mutex)
+// BenchmarkSearchSlot-6 100000 14551 ns/op (when checking parent number using atomic)
+// With bloom filter:
+// BenchmarkSearchSlot-6 3467835 351 ns/op
+func BenchmarkSearchSlot(b *testing.B) {
+ // First, we set up 128 diff layers, with 1K items each
+ accountKey := types.BytesToHash(keccak.Keccak256(nil, []byte{0x13, 0x37}))
+ storageKey := types.BytesToHash(keccak.Keccak256(nil, []byte{0x13, 0x37}))
+ accountRLP := randomAccount()
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[types.Hash]struct{})
+ accounts = make(map[types.Hash][]byte)
+ storage = make(map[types.Hash]map[types.Hash][]byte)
+ logger = hclog.NewNullLogger()
+ snapmetrics = NilMetrics()
+ )
+
+ accounts[accountKey] = accountRLP
+ accStorage := make(map[types.Hash][]byte)
+
+ for i := 0; i < 5; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ storage[accountKey] = accStorage
+ }
+
+ return newDiffLayer(parent, types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+ }
+
+ var layer snapshot = emptyLayer()
+ for i := 0; i < 128; i++ {
+ layer = fill(layer)
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ layer.Storage(accountKey, storageKey)
+ }
+}
+
+// With accountList and sorting
+// BenchmarkFlatten-6 50 29890856 ns/op
+//
+// Without sorting and tracking accountList
+// BenchmarkFlatten-6 300 5511511 ns/op
+func BenchmarkFlatten(b *testing.B) {
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[types.Hash]struct{})
+ accounts = make(map[types.Hash][]byte)
+ storage = make(map[types.Hash]map[types.Hash][]byte)
+ logger = hclog.NewNullLogger()
+ snapmetrics = NilMetrics()
+ )
+
+ for i := 0; i < 100; i++ {
+ accountKey := randomHash()
+ accounts[accountKey] = randomAccount()
+ accStorage := make(map[types.Hash][]byte)
+
+ for i := 0; i < 20; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ }
+
+ storage[accountKey] = accStorage
+ }
+
+ return newDiffLayer(parent, types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+
+ var layer snapshot = emptyLayer()
+ for i := 1; i < 128; i++ {
+ layer = fill(layer)
+ }
+
+ b.StartTimer()
+
+ for i := 1; i < 128; i++ {
+ dl, ok := layer.(*diffLayer)
+ if !ok {
+ break
+ }
+
+ layer = dl.flatten()
+ }
+
+ b.StopTimer()
+ }
+}
+
+// This test writes ~324M of diff layers to disk, spread over
+// - 128 individual layers,
+// - each with 200 accounts
+// - containing 200 slots
+//
+// BenchmarkJournal-6 1 1471373923 ns/ops
+// BenchmarkJournal-6 1 1208083335 ns/op // bufio writer
+func BenchmarkJournal(b *testing.B) {
+ fill := func(parent snapshot) *diffLayer {
+ var (
+ destructs = make(map[types.Hash]struct{})
+ accounts = make(map[types.Hash][]byte)
+ storage = make(map[types.Hash]map[types.Hash][]byte)
+ logger = hclog.NewNullLogger()
+ snapmetrics = NilMetrics()
+ )
+
+ for i := 0; i < 200; i++ {
+ accountKey := randomHash()
+ accounts[accountKey] = randomAccount()
+ accStorage := make(map[types.Hash][]byte)
+
+ for i := 0; i < 200; i++ {
+ value := make([]byte, 32)
+ rand.Read(value)
+ accStorage[randomHash()] = value
+ }
+
+ storage[accountKey] = accStorage
+ }
+
+ return newDiffLayer(parent, types.Hash{}, destructs, accounts, storage, logger, snapmetrics)
+ }
+
+ layer := snapshot(emptyLayer())
+ for i := 1; i < 128; i++ {
+ layer = fill(layer)
+ }
+
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ layer.Journal(new(bytes.Buffer))
+ }
+}
diff --git a/state/snapshot/disklayer.go b/state/snapshot/disklayer.go
new file mode 100644
index 0000000000..fc705a8d45
--- /dev/null
+++ b/state/snapshot/disklayer.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "math/big"
+ "sync"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// diskLayer is a low level persistent snapshot built on top of a key-value store.
+type diskLayer struct {
+ diskdb kvdb.KVBatchStorage // Key-value store containing the base snapshot
+ triedb *trie.Database // Trie node cache for reconstruction purposes
+ cache *fastcache.Cache // Cache to avoid hitting the disk for direct access
+
+ root types.Hash // Root hash of the base snapshot
+ stale bool // Signals that the layer became stale (state progressed)
+
+ genMarker []byte // Marker for the state that's indexed during initial layer generation
+ genPending chan struct{} // Notification channel when generation is done (test synchronicity)
+ genAbort chan chan *generatorStats // Notification channel to abort generating the snapshot in this layer
+
+ lock sync.RWMutex
+
+ logger kvdb.Logger
+ snapmetrics *Metrics
+}
+
+// Root returns root hash for which this snapshot was made.
+func (dl *diskLayer) Root() types.Hash {
+ return dl.root
+}
+
+// Parent always returns nil as there's no layer below the disk.
+func (dl *diskLayer) Parent() snapshot {
+ return nil
+}
+
+// Stale return whether this layer has become stale (was flattened across) or if
+// it's still live.
+func (dl *diskLayer) Stale() bool {
+	// The stale flag is read under the layer lock to stay consistent with
+	// whichever writer flips it while holding the write lock.
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	return dl.stale
+}
+
+// Account directly retrieves the account associated with a particular hash in
+// the snapshot slim data format.
+//
+// Returns (nil, nil) when the snapshot does not contain the account.
+func (dl *diskLayer) Account(hash types.Hash) (*stypes.Account, error) {
+	data, err := dl.AccountRLP(hash)
+	if err != nil {
+		return nil, err
+	}
+
+	if len(data) == 0 { // can be both nil and []byte{}
+		return nil, nil
+	}
+
+	// Snapshot blobs are written by this package itself, so a decode failure
+	// indicates database corruption rather than bad input — treat it as fatal.
+	var slimAccount Account
+	if err := rlp.DecodeBytes(data, &slimAccount); err != nil {
+		panic(err)
+	}
+
+	// copy balance to heap
+	// NOTE(review): assumes slimAccount.Balance is non-nil after decoding —
+	// big.Int.Set would panic on nil; confirm the encoder always writes it.
+	balance := new(big.Int).Set(slimAccount.Balance)
+
+	return &stypes.Account{
+		Nonce:       slimAccount.Nonce,
+		Balance:     balance,
+		StorageRoot: types.BytesToHash(slimAccount.Root),
+		CodeHash:    slimAccount.CodeHash,
+	}, nil
+}
+
+// AccountRLP directly retrieves the account RLP associated with a particular
+// hash in the snapshot slim data format.
+//
+// Lookup order: fastcache, then the key-value store (whose result is cached,
+// including the empty "account does not exist" answer).
+func (dl *diskLayer) AccountRLP(hash types.Hash) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
+	// If the layer is being generated, ensure the requested hash has already been
+	// covered by the generator.
+	if dl.genMarker != nil && bytes.Compare(hash[:], dl.genMarker) > 0 {
+		return nil, ErrNotCoveredYet
+	}
+
+	// If we're in the disk layer, all diff layers missed
+	metrics.CounterInc(dl.snapmetrics.dirtyAccountMissCount)
+
+	// Try to retrieve the account from the memory cache
+	if blob, found := dl.cache.HasGet(nil, hash[:]); found {
+		metrics.CounterInc(dl.snapmetrics.cleanAccountHitCount)
+		metrics.HistogramObserve(dl.snapmetrics.cleanAccountReadSize, float64(len(blob)))
+
+		return blob, nil
+	}
+
+	// Cache doesn't contain account, pull from disk and cache for later
+	blob := rawdb.ReadAccountSnapshot(dl.diskdb, hash)
+	dl.cache.Set(hash[:], blob)
+
+	metrics.CounterInc(dl.snapmetrics.cleanAccountMissCount)
+	// Record whether the disk read carried data ("write" size) or found
+	// nothing ("inex", i.e. inexistent entry).
+	if n := len(blob); n > 0 {
+		metrics.HistogramObserve(dl.snapmetrics.cleanAccountWriteSize, float64(n))
+	} else {
+		metrics.CounterInc(dl.snapmetrics.cleanAccountInexCount)
+	}
+
+	return blob, nil
+}
+
+// Storage directly retrieves the storage data associated with a particular hash,
+// within a particular account.
+//
+// Lookup order mirrors AccountRLP: fastcache first, then the key-value store
+// (whose result — including the empty answer — is cached for later).
+func (dl *diskLayer) Storage(accountHash, storageHash types.Hash) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
+
+	// Composite cache key: account hash followed by slot hash. accountHash is
+	// a by-value array copy, so append allocates a fresh backing slice here.
+	key := append(accountHash[:], storageHash[:]...)
+
+	// If the layer is being generated, ensure the requested hash has already been
+	// covered by the generator.
+	if dl.genMarker != nil && bytes.Compare(key, dl.genMarker) > 0 {
+		return nil, ErrNotCoveredYet
+	}
+
+	// If we're in the disk layer, all diff layers missed
+	metrics.CounterInc(dl.snapmetrics.dirtyStorageMissCount)
+
+	// Try to retrieve the storage slot from the memory cache
+	if blob, found := dl.cache.HasGet(nil, key); found {
+		metrics.CounterInc(dl.snapmetrics.cleanStorageHitCount)
+		metrics.HistogramObserve(dl.snapmetrics.cleanStorageReadSize, float64(len(blob)))
+
+		return blob, nil
+	}
+	// Cache doesn't contain storage slot, pull from disk and cache for later
+	blob := rawdb.ReadStorageSnapshot(dl.diskdb, accountHash, storageHash)
+	dl.cache.Set(key, blob)
+
+	metrics.CounterInc(dl.snapmetrics.cleanStorageMissCount)
+	// Record whether the disk read carried data ("write" size) or found
+	// nothing ("inex", i.e. inexistent entry).
+	if n := len(blob); n > 0 {
+		metrics.HistogramObserve(dl.snapmetrics.cleanStorageWriteSize, float64(n))
+	} else {
+		metrics.CounterInc(dl.snapmetrics.cleanStorageInexCount)
+	}
+
+	return blob, nil
+}
+
+// Update creates a new layer on top of the existing snapshot diff tree with
+// the specified data items. Note, the maps are retained by the method to avoid
+// copying everything.
+//
+// The new diff layer inherits this disk layer's metrics sink but uses the
+// caller-supplied logger.
+func (dl *diskLayer) Update(
+	blockHash types.Hash,
+	destructs map[types.Hash]struct{},
+	accounts map[types.Hash][]byte,
+	storage map[types.Hash]map[types.Hash][]byte,
+	logger kvdb.Logger,
+) *diffLayer {
+	return newDiffLayer(dl, blockHash, destructs, accounts, storage, logger, dl.snapmetrics)
+}
diff --git a/state/snapshot/disklayer_test.go b/state/snapshot/disklayer_test.go
new file mode 100644
index 0000000000..522ff1af79
--- /dev/null
+++ b/state/snapshot/disklayer_test.go
@@ -0,0 +1,647 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// reverse reverses the contents of a byte slice. It's used to update random accs
+// with deterministic changes. The input is left untouched; a new slice of the
+// same length is returned.
+func reverse(blob []byte) []byte {
+	res := make([]byte, len(blob))
+	for i, b := range blob {
+		res[len(blob)-1-i] = b
+	}
+
+	return res
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values.
+func TestDiskMerge(t *testing.T) {
+ // Create some accounts in the disk layer
+ db := memorydb.New()
+
+ var (
+ accNoModNoCache = types.Hash{0x1}
+ accNoModCache = types.Hash{0x2}
+ accModNoCache = types.Hash{0x3}
+ accModCache = types.Hash{0x4}
+ accDelNoCache = types.Hash{0x5}
+ accDelCache = types.Hash{0x6}
+ conNoModNoCache = types.Hash{0x7}
+ conNoModNoCacheSlot = types.Hash{0x70}
+ conNoModCache = types.Hash{0x8}
+ conNoModCacheSlot = types.Hash{0x80}
+ conModNoCache = types.Hash{0x9}
+ conModNoCacheSlot = types.Hash{0x90}
+ conModCache = types.Hash{0xa}
+ conModCacheSlot = types.Hash{0xa0}
+ conDelNoCache = types.Hash{0xb}
+ conDelNoCacheSlot = types.Hash{0xb0}
+ conDelCache = types.Hash{0xc}
+ conDelCacheSlot = types.Hash{0xc0}
+ conNukeNoCache = types.Hash{0xd}
+ conNukeNoCacheSlot = types.Hash{0xd0}
+ conNukeCache = types.Hash{0xe}
+ conNukeCacheSlot = types.Hash{0xe0}
+ baseRoot = randomHash()
+ diffRoot = randomHash()
+ )
+
+ rawdb.WriteAccountSnapshot(db, accNoModNoCache, accNoModNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accNoModCache, accNoModCache[:])
+ rawdb.WriteAccountSnapshot(db, accModNoCache, accModNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accModCache, accModCache[:])
+ rawdb.WriteAccountSnapshot(db, accDelNoCache, accDelNoCache[:])
+ rawdb.WriteAccountSnapshot(db, accDelCache, accDelCache[:])
+
+ rawdb.WriteAccountSnapshot(db, conNoModNoCache, conNoModNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conNoModCache, conNoModCache[:])
+ rawdb.WriteStorageSnapshot(db, conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conModNoCache, conModNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conModCache, conModCache[:])
+ rawdb.WriteStorageSnapshot(db, conModCache, conModCacheSlot, conModCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conDelNoCache, conDelNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conDelCache, conDelCache[:])
+ rawdb.WriteStorageSnapshot(db, conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+ rawdb.WriteAccountSnapshot(db, conNukeNoCache, conNukeNoCache[:])
+ rawdb.WriteStorageSnapshot(db, conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+ rawdb.WriteAccountSnapshot(db, conNukeCache, conNukeCache[:])
+ rawdb.WriteStorageSnapshot(db, conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+ rawdb.WriteSnapshotRoot(db, baseRoot)
+
+ // Create a disk layer based on the above and cache in some data
+ snaps := &Tree{
+ layers: map[types.Hash]snapshot{
+ baseRoot: &diskLayer{
+ diskdb: db,
+ cache: fastcache.New(500 * 1024),
+ root: baseRoot,
+ logger: hclog.NewNullLogger(),
+ snapmetrics: NilMetrics(),
+ },
+ },
+ }
+
+ base := snaps.Snapshot(baseRoot)
+ base.AccountRLP(accNoModCache)
+ base.AccountRLP(accModCache)
+ base.AccountRLP(accDelCache)
+ base.Storage(conNoModCache, conNoModCacheSlot)
+ base.Storage(conModCache, conModCacheSlot)
+ base.Storage(conDelCache, conDelCacheSlot)
+ base.Storage(conNukeCache, conNukeCacheSlot)
+
+ // Modify or delete some accounts, flatten everything onto disk
+ if err := snaps.Update(
+ diffRoot,
+ baseRoot,
+ map[types.Hash]struct{}{
+ accDelNoCache: {},
+ accDelCache: {},
+ conNukeNoCache: {},
+ conNukeCache: {},
+ },
+ map[types.Hash][]byte{
+ accModNoCache: reverse(accModNoCache[:]),
+ accModCache: reverse(accModCache[:]),
+ },
+ map[types.Hash]map[types.Hash][]byte{
+ conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+ conModCache: {conModCacheSlot: reverse(conModCacheSlot[:])},
+ conDelNoCache: {conDelNoCacheSlot: nil},
+ conDelCache: {conDelCacheSlot: nil},
+ },
+ hclog.NewNullLogger(),
+ ); err != nil {
+ t.Fatalf("failed to update snapshot tree: %v", err)
+ }
+
+ if err := snaps.Cap(diffRoot, 0); err != nil {
+ t.Fatalf("failed to flatten snapshot tree: %v", err)
+ }
+
+ // Retrieve all the data through the disk layer and validate it
+ base = snaps.Snapshot(diffRoot)
+ if _, ok := base.(*diskLayer); !ok {
+ t.Fatalf("update not flattend into the disk layer")
+ }
+
+ // assertAccount ensures that an account matches the given blob.
+ assertAccount := func(account types.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.AccountRLP(account)
+ if err != nil {
+ t.Errorf("account access (%s) failed: %v", account, err)
+ } else if !bytes.Equal(blob, data) {
+ t.Errorf("account access (%s) mismatch: have %x, want %x", account, blob, data)
+ }
+ }
+
+ assertAccount(accNoModNoCache, accNoModNoCache[:])
+ assertAccount(accNoModCache, accNoModCache[:])
+ assertAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertAccount(accModCache, reverse(accModCache[:]))
+ assertAccount(accDelNoCache, nil)
+ assertAccount(accDelCache, nil)
+
+ // assertStorage ensures that a storage slot matches the given blob.
+ assertStorage := func(account types.Hash, slot types.Hash, data []byte) {
+ t.Helper()
+ blob, err := base.Storage(account, slot)
+ if err != nil {
+ t.Errorf("storage access (%s:%s) failed: %v", account, slot, err)
+ } else if !bytes.Equal(blob, data) {
+ t.Errorf("storage access (%s:%s) mismatch: have %x, want %x", account, slot, blob, data)
+ }
+ }
+
+ assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertStorage(conDelCache, conDelCacheSlot, nil)
+ assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertStorage(conNukeCache, conNukeCacheSlot, nil)
+
+ // Retrieve all the data directly from the database and validate it
+
+ // assertDatabaseAccount ensures that an account from the database matches the given blob.
+ assertDatabaseAccount := func(account types.Hash, data []byte) {
+ t.Helper()
+ if blob := rawdb.ReadAccountSnapshot(db, account); !bytes.Equal(blob, data) {
+ t.Errorf("account database access (%s) mismatch: have %x, want %x", account, blob, data)
+ }
+ }
+
+ assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:])
+ assertDatabaseAccount(accNoModCache, accNoModCache[:])
+ assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:]))
+ assertDatabaseAccount(accModCache, reverse(accModCache[:]))
+ assertDatabaseAccount(accDelNoCache, nil)
+ assertDatabaseAccount(accDelCache, nil)
+
+ // assertDatabaseStorage ensures that a storage slot from the database matches the given blob.
+ assertDatabaseStorage := func(account types.Hash, slot types.Hash, data []byte) {
+ t.Helper()
+ if blob := rawdb.ReadStorageSnapshot(db, account, slot); !bytes.Equal(blob, data) {
+ t.Errorf("storage database access (%s:%s) mismatch: have %x, want %x", account, slot, blob, data)
+ }
+ }
+
+ assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+ assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+ assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+ assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+ assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil)
+ assertDatabaseStorage(conDelCache, conDelCacheSlot, nil)
+ assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+ assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil)
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values, discarding anything
+// after the in-progress generation marker.
+func TestDiskPartialMerge(t *testing.T) {
+	// Iterate the test a few times to ensure we pick various internal orderings
+	// for the data slots as well as the progress marker.
+	for i := 0; i < 1024; i++ {
+		// Create some accounts in the disk layer
+		db := memorydb.New()
+
+		// Naming scheme: acc* entries are plain accounts, con* entries are
+		// contracts with a single storage slot each. NoMod/Mod/Del/Nuke
+		// describe what the diff layer later does to the entry (keep, modify,
+		// delete the account, delete the whole contract), while the
+		// Cache/NoCache suffix marks whether the value is read (and thus
+		// cached) before the diff layer is flattened onto disk.
+		var (
+			accNoModNoCache     = randomHash()
+			accNoModCache       = randomHash()
+			accModNoCache       = randomHash()
+			accModCache         = randomHash()
+			accDelNoCache       = randomHash()
+			accDelCache         = randomHash()
+			conNoModNoCache     = randomHash()
+			conNoModNoCacheSlot = randomHash()
+			conNoModCache       = randomHash()
+			conNoModCacheSlot   = randomHash()
+			conModNoCache       = randomHash()
+			conModNoCacheSlot   = randomHash()
+			conModCache         = randomHash()
+			conModCacheSlot     = randomHash()
+			conDelNoCache       = randomHash()
+			conDelNoCacheSlot   = randomHash()
+			conDelCache         = randomHash()
+			conDelCacheSlot     = randomHash()
+			conNukeNoCache      = randomHash()
+			conNukeNoCacheSlot  = randomHash()
+			conNukeCache        = randomHash()
+			conNukeCacheSlot    = randomHash()
+			baseRoot            = randomHash()
+			diffRoot            = randomHash()
+			// genMarker is accountHash ++ storageHash (64 bytes), i.e. the
+			// generator paused somewhere in the middle of an account's storage.
+			genMarker = append(randomHash().Bytes(), randomHash().Bytes()...)
+		)
+
+		// insertAccount injects an account into the database if it's after the
+		// generator marker, drops the op otherwise. This is needed to seed the
+		// database with a valid starting snapshot.
+		insertAccount := func(account types.Hash, data []byte) {
+			if bytes.Compare(account[:], genMarker) <= 0 {
+				rawdb.WriteAccountSnapshot(db, account, data[:])
+			}
+		}
+
+		insertAccount(accNoModNoCache, accNoModNoCache[:])
+		insertAccount(accNoModCache, accNoModCache[:])
+		insertAccount(accModNoCache, accModNoCache[:])
+		insertAccount(accModCache, accModCache[:])
+		insertAccount(accDelNoCache, accDelNoCache[:])
+		insertAccount(accDelCache, accDelCache[:])
+
+		// insertStorage injects a storage slot into the database if it's after
+		// the generator marker, drops the op otherwise. This is needed to seed
+		// the database with a valid starting snapshot.
+		insertStorage := func(account types.Hash, slot types.Hash, data []byte) {
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 {
+				rawdb.WriteStorageSnapshot(db, account, slot, data[:])
+			}
+		}
+
+		insertAccount(conNoModNoCache, conNoModNoCache[:])
+		insertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+		insertAccount(conNoModCache, conNoModCache[:])
+		insertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		insertAccount(conModNoCache, conModNoCache[:])
+		insertStorage(conModNoCache, conModNoCacheSlot, conModNoCacheSlot[:])
+		insertAccount(conModCache, conModCache[:])
+		insertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+		insertAccount(conDelNoCache, conDelNoCache[:])
+		insertStorage(conDelNoCache, conDelNoCacheSlot, conDelNoCacheSlot[:])
+		insertAccount(conDelCache, conDelCache[:])
+		insertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+
+		insertAccount(conNukeNoCache, conNukeNoCache[:])
+		insertStorage(conNukeNoCache, conNukeNoCacheSlot, conNukeNoCacheSlot[:])
+		insertAccount(conNukeCache, conNukeCache[:])
+		insertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+		rawdb.WriteSnapshotRoot(db, baseRoot)
+
+		// Create a disk layer based on the above using a random progress marker
+		// and cache in some data.
+		snaps := &Tree{
+			layers: map[types.Hash]snapshot{
+				baseRoot: &diskLayer{
+					diskdb:      db,
+					cache:       fastcache.New(500 * 1024),
+					root:        baseRoot,
+					logger:      hclog.NewNullLogger(),
+					snapmetrics: NilMetrics(),
+				},
+			},
+		}
+		snaps.layers[baseRoot].(*diskLayer).genMarker = genMarker
+		base := snaps.Snapshot(baseRoot)
+
+		// assertAccount ensures that an account matches the given blob if it's
+		// already covered by the disk snapshot, and errors out otherwise.
+		assertAccount := func(account types.Hash, data []byte) {
+			t.Helper()
+			blob, err := base.AccountRLP(account)
+			// Anything past the generator marker must report "not covered yet".
+			if bytes.Compare(account[:], genMarker) > 0 && err != ErrNotCoveredYet {
+				t.Fatalf("test %d: post-marker (%s) account access (%s) succeeded: %x", i, genMarker, account, blob)
+			}
+			if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%s) account access (%s) mismatch: have %x, want %x", i, genMarker, account, blob, data)
+			}
+		}
+
+		// Warm the snapshot cache for the *Cache entries before flattening.
+		assertAccount(accNoModCache, accNoModCache[:])
+		assertAccount(accModCache, accModCache[:])
+		assertAccount(accDelCache, accDelCache[:])
+
+		// assertStorage ensures that a storage slot matches the given blob if
+		// it's already covered by the disk snapshot, and errors out otherwise.
+		assertStorage := func(account types.Hash, slot types.Hash, data []byte) {
+			t.Helper()
+			blob, err := base.Storage(account, slot)
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && err != ErrNotCoveredYet {
+				t.Fatalf("test %d: post-marker (%x) storage access (%s:%s) succeeded: %x", i, genMarker, account, slot, blob)
+			}
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%x) storage access (%s:%s) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data)
+			}
+		}
+
+		assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		assertStorage(conModCache, conModCacheSlot, conModCacheSlot[:])
+		assertStorage(conDelCache, conDelCacheSlot, conDelCacheSlot[:])
+		assertStorage(conNukeCache, conNukeCacheSlot, conNukeCacheSlot[:])
+
+		// Modify or delete some accounts, flatten everything onto disk
+		if err := snaps.Update(
+			diffRoot,
+			baseRoot,
+			map[types.Hash]struct{}{
+				accDelNoCache:  {},
+				accDelCache:    {},
+				conNukeNoCache: {},
+				conNukeCache:   {},
+			},
+			map[types.Hash][]byte{
+				accModNoCache: reverse(accModNoCache[:]),
+				accModCache:   reverse(accModCache[:]),
+			},
+			map[types.Hash]map[types.Hash][]byte{
+				conModNoCache: {conModNoCacheSlot: reverse(conModNoCacheSlot[:])},
+				conModCache:   {conModCacheSlot: reverse(conModCacheSlot[:])},
+				conDelNoCache: {conDelNoCacheSlot: nil},
+				conDelCache:   {conDelCacheSlot: nil},
+			},
+			hclog.NewNullLogger(),
+		); err != nil {
+			t.Fatalf("test %d: failed to update snapshot tree: %v", i, err)
+		}
+
+		// Cap with 0 diff layers forces the diff to merge into the disk layer.
+		if err := snaps.Cap(diffRoot, 0); err != nil {
+			t.Fatalf("test %d: failed to flatten snapshot tree: %v", i, err)
+		}
+
+		// Retrieve all the data through the disk layer and validate it
+		base = snaps.Snapshot(diffRoot)
+		if _, ok := base.(*diskLayer); !ok {
+			t.Fatalf("test %d: update not flattend into the disk layer", i)
+		}
+
+		assertAccount(accNoModNoCache, accNoModNoCache[:])
+		assertAccount(accNoModCache, accNoModCache[:])
+		assertAccount(accModNoCache, reverse(accModNoCache[:]))
+		assertAccount(accModCache, reverse(accModCache[:]))
+		assertAccount(accDelNoCache, nil)
+		assertAccount(accDelCache, nil)
+
+		assertStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+		assertStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		assertStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+		assertStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+		assertStorage(conDelNoCache, conDelNoCacheSlot, nil)
+		assertStorage(conDelCache, conDelCacheSlot, nil)
+		assertStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+		assertStorage(conNukeCache, conNukeCacheSlot, nil)
+
+		// Retrieve all the data directly from the database and validate it
+
+		// assertDatabaseAccount ensures that an account inside the database matches
+		// the given blob if it's already covered by the disk snapshot, and does not
+		// exist otherwise.
+		assertDatabaseAccount := func(account types.Hash, data []byte) {
+			t.Helper()
+			blob := rawdb.ReadAccountSnapshot(db, account)
+			if bytes.Compare(account[:], genMarker) > 0 && blob != nil {
+				t.Fatalf("test %d: post-marker (%x) account database access (%s) succeeded: %x", i, genMarker, account, blob)
+			}
+			if bytes.Compare(account[:], genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%x) account database access (%s) mismatch: have %x, want %x", i, genMarker, account, blob, data)
+			}
+		}
+
+		assertDatabaseAccount(accNoModNoCache, accNoModNoCache[:])
+		assertDatabaseAccount(accNoModCache, accNoModCache[:])
+		assertDatabaseAccount(accModNoCache, reverse(accModNoCache[:]))
+		assertDatabaseAccount(accModCache, reverse(accModCache[:]))
+		assertDatabaseAccount(accDelNoCache, nil)
+		assertDatabaseAccount(accDelCache, nil)
+
+		// assertDatabaseStorage ensures that a storage slot inside the database
+		// matches the given blob if it's already covered by the disk snapshot,
+		// and does not exist otherwise.
+		assertDatabaseStorage := func(account types.Hash, slot types.Hash, data []byte) {
+			t.Helper()
+			blob := rawdb.ReadStorageSnapshot(db, account, slot)
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) > 0 && blob != nil {
+				t.Fatalf("test %d: post-marker (%x) storage database access (%s:%s) succeeded: %x", i, genMarker, account, slot, blob)
+			}
+			if bytes.Compare(append(account[:], slot[:]...), genMarker) <= 0 && !bytes.Equal(blob, data) {
+				t.Fatalf("test %d: pre-marker (%x) storage database access (%s:%s) mismatch: have %x, want %x", i, genMarker, account, slot, blob, data)
+			}
+		}
+
+		assertDatabaseStorage(conNoModNoCache, conNoModNoCacheSlot, conNoModNoCacheSlot[:])
+		assertDatabaseStorage(conNoModCache, conNoModCacheSlot, conNoModCacheSlot[:])
+		assertDatabaseStorage(conModNoCache, conModNoCacheSlot, reverse(conModNoCacheSlot[:]))
+		assertDatabaseStorage(conModCache, conModCacheSlot, reverse(conModCacheSlot[:]))
+		assertDatabaseStorage(conDelNoCache, conDelNoCacheSlot, nil)
+		assertDatabaseStorage(conDelCache, conDelCacheSlot, nil)
+		assertDatabaseStorage(conNukeNoCache, conNukeNoCacheSlot, nil)
+		assertDatabaseStorage(conNukeCache, conNukeCacheSlot, nil)
+	}
+}
+
+// Tests that when the bottom-most diff layer is merged into the disk
+// layer whether the corresponding generator is persisted correctly.
+func TestDiskGeneratorPersistence(t *testing.T) {
+	var (
+		accOne        = randomHash()
+		accTwo        = randomHash()
+		accOneSlotOne = randomHash()
+		accOneSlotTwo = randomHash()
+
+		accThree     = randomHash()
+		accThreeSlot = randomHash()
+		baseRoot     = randomHash()
+		diffRoot     = randomHash()
+		diffTwoRoot  = randomHash()
+		// A 64-byte marker (accountHash ++ storageHash) signals generation
+		// paused mid-account.
+		genMarker = append(randomHash().Bytes(), randomHash().Bytes()...)
+	)
+
+	// Testing scenario 1, the disk layer is still under the construction.
+	db := rawdb.NewMemoryDatabase()
+
+	rawdb.WriteAccountSnapshot(db, accOne, accOne[:])
+	rawdb.WriteStorageSnapshot(db, accOne, accOneSlotOne, accOneSlotOne[:])
+	rawdb.WriteStorageSnapshot(db, accOne, accOneSlotTwo, accOneSlotTwo[:])
+	rawdb.WriteSnapshotRoot(db, baseRoot)
+
+	// Create a disk layer based on all above updates
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			baseRoot: &diskLayer{
+				diskdb:      db,
+				cache:       fastcache.New(500 * 1024),
+				root:        baseRoot,
+				genMarker:   genMarker,
+				logger:      hclog.NewNullLogger(),
+				snapmetrics: NilMetrics(),
+			},
+		},
+	}
+
+	// Modify or delete some accounts, flatten everything onto disk
+	if err := snaps.Update(
+		diffRoot,
+		baseRoot,
+		nil,
+		map[types.Hash][]byte{
+			accTwo: accTwo[:],
+		},
+		nil,
+		hclog.NewNullLogger(),
+	); err != nil {
+		t.Fatalf("failed to update snapshot tree: %v", err)
+	}
+
+	if err := snaps.Cap(diffRoot, 0); err != nil {
+		t.Fatalf("failed to flatten snapshot tree: %v", err)
+	}
+
+	// Since generation is still in progress, flattening must persist the
+	// original marker unchanged.
+	blob := rawdb.ReadSnapshotGenerator(db)
+
+	var generator journalGenerator
+	if err := rlp.DecodeBytes(blob, &generator); err != nil {
+		t.Fatalf("Failed to decode snapshot generator %v", err)
+	}
+
+	if !bytes.Equal(generator.Marker, genMarker) {
+		t.Fatalf("Generator marker is not matched")
+	}
+
+	// Test scenario 2, the disk layer is fully generated
+	// Modify or delete some accounts, flatten everything onto disk
+	if err := snaps.Update(
+		diffTwoRoot,
+		diffRoot,
+		nil,
+		map[types.Hash][]byte{
+			accThree: accThree.Bytes(),
+		},
+		map[types.Hash]map[types.Hash][]byte{
+			accThree: {accThreeSlot: accThreeSlot.Bytes()},
+		},
+		hclog.NewNullLogger(),
+	); err != nil {
+		t.Fatalf("failed to update snapshot tree: %v", err)
+	}
+
+	diskLayer := snaps.layers[snaps.diskRoot()].(*diskLayer)
+	diskLayer.genMarker = nil // Construction finished
+
+	if err := snaps.Cap(diffTwoRoot, 0); err != nil {
+		t.Fatalf("failed to flatten snapshot tree: %v", err)
+	}
+
+	// With generation marked finished, the persisted generator must report an
+	// empty (done) marker.
+	blob = rawdb.ReadSnapshotGenerator(db)
+
+	if err := rlp.DecodeBytes(blob, &generator); err != nil {
+		t.Fatalf("Failed to decode snapshot generator %v", err)
+	}
+
+	if len(generator.Marker) != 0 {
+		t.Fatalf("Failed to update snapshot generator")
+	}
+}
+
+// Tests that merging something into a disk layer persists it into the database
+// and invalidates any previously written and cached values, discarding anything
+// after the in-progress generation marker.
+//
+// This test case is a tiny specialized case of TestDiskPartialMerge, which tests
+// some very specific cornercases that random tests won't ever trigger.
+func TestDiskMidAccountPartialMerge(t *testing.T) {
+	// TODO(@karalabe) ? — intentionally unimplemented placeholder: should pin
+	// a generation marker that stops in the middle of an account's storage and
+	// verify the flattening behavior around that exact boundary.
+}
+
+// TestDiskSeek tests that seek-operations work on the disk layer
+func TestDiskSeek(t *testing.T) {
+	// Create some accounts in the disk layer
+	db := rawdb.NewMemoryDatabase()
+	defer db.Close()
+
+	// Fill even keys [0,2,4...]; seeking to an odd position must therefore
+	// land on the next even key.
+	for i := 0; i < 0xff; i += 2 {
+		acc := types.Hash{byte(i)}
+		rawdb.WriteAccountSnapshot(db, acc, acc[:])
+	}
+
+	// Add an 'higher' key, with incorrect (higher) prefix. The iterator must
+	// never surface it even though it sorts after all snapshot entries.
+	highKey := []byte{rawdb.SnapshotAccountPrefix[0] + 1, rawdb.SnapshotAccountPrefix[1]}
+	db.Set(highKey, []byte{0xff, 0xff})
+
+	baseRoot := randomHash()
+	rawdb.WriteSnapshotRoot(db, baseRoot)
+
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			baseRoot: &diskLayer{
+				diskdb:      db,
+				cache:       fastcache.New(500 * 1024),
+				root:        baseRoot,
+				logger:      hclog.NewNullLogger(),
+				snapmetrics: NilMetrics(),
+			},
+		},
+	}
+
+	// Test some different seek positions
+	type testcase struct {
+		pos    byte // first byte of the seek position
+		expkey byte // first byte of the first key the iterator should yield
+	}
+
+	var cases = []testcase{
+		{0xff, 0x55}, // this should exit immediately without checking key
+		{0x01, 0x02},
+		{0xfe, 0xfe},
+		{0xfd, 0xfe},
+		{0x00, 0x00},
+	}
+	for i, tc := range cases {
+		it, err := snaps.AccountIterator(baseRoot, types.Hash{tc.pos})
+		if err != nil {
+			t.Fatalf("case %d, error: %v", i, err)
+		}
+
+		count := 0
+
+		for it.Next() {
+			k, v, err := it.Hash()[0], it.Account()[0], it.Error()
+			if err != nil {
+				t.Fatalf("test %d, item %d, error: %v", i, count, err)
+			}
+			// First item in iterator should have the expected key
+			if count == 0 && k != tc.expkey {
+				t.Fatalf("test %d, item %d, got %v exp %v", i, count, k, tc.expkey)
+			}
+
+			count++
+
+			// Seeded values equal their keys, so every yielded pair must match.
+			if v != k {
+				t.Fatalf("test %d, item %d, value wrong, got %v exp %v", i, count, v, k)
+			}
+		}
+	}
+}
diff --git a/state/snapshot/generate.go b/state/snapshot/generate.go
new file mode 100644
index 0000000000..f6848f5331
--- /dev/null
+++ b/state/snapshot/generate.go
@@ -0,0 +1,923 @@
+package snapshot
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "math/big"
+ "os"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/hex"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+var (
+	// accountCheckRange is the upper limit of the number of accounts involved in
+	// each range check. This is a value estimated based on experience. If this
+	// range is too large, the failure rate of range proof will increase. Otherwise,
+	// if the range is too small, the efficiency of the state recovery will decrease.
+	accountCheckRange = 128
+
+	// storageCheckRange is the upper limit of the number of storage slots involved
+	// in each range check. This is a value estimated based on experience. If this
+	// range is too large, the failure rate of range proof will increase. Otherwise,
+	// if the range is too small, the efficiency of the state recovery will decrease.
+	storageCheckRange = 1024
+
+	// errMissingTrie is returned if the target trie is missing while the generation
+	// is running. In this case the generation is aborted and waits for a new
+	// resume signal.
+	errMissingTrie = errors.New("missing trie")
+)
+
+// generateSnapshot regenerates a brand new snapshot based on an existing state
+// database and head block asynchronously. The snapshot is returned immediately
+// and generation is continued in the background until done.
+func generateSnapshot(
+	diskdb kvdb.KVBatchStorage,
+	triedb *trie.Database,
+	cache int,
+	root types.Hash,
+	logger kvdb.Logger,
+	snapmetrics *Metrics,
+) *diskLayer {
+	// Create a new disk layer with an initialized state marker at zero
+	var (
+		stats     = &generatorStats{start: time.Now(), logger: logger, generateMetrics: snapmetrics}
+		batch     = diskdb.NewBatch()
+		genMarker = []byte{} // Initialized but empty!
+	)
+
+	// write batch to db and journal. Root and progress marker are committed in
+	// a single batch so disk never shows a snapshot root without a generator
+	// entry.
+	rawdb.WriteSnapshotRoot(batch, root)
+	journalProgress(batch, genMarker, stats, logger)
+
+	if err := batch.Write(); err != nil {
+		// Failing here would leave the on-disk snapshot state undefined, so it
+		// is treated as fatal.
+		logger.Error("Failed to write initialized state marker", "err", err)
+		os.Exit(1)
+	}
+
+	base := &diskLayer{
+		diskdb:      diskdb,
+		triedb:      triedb,
+		root:        root,
+		cache:       fastcache.New(cache * 1024 * 1024),
+		genMarker:   genMarker,
+		genPending:  make(chan struct{}),
+		genAbort:    make(chan chan *generatorStats),
+		logger:      logger,
+		snapmetrics: snapmetrics,
+	}
+
+	// Kick off background generation only after the marker is durably on disk.
+	go base.generate(stats)
+
+	logger.Debug("Start snapshot generation", "root", root)
+
+	return base
+}
+
+// journalProgress persists the generator stats into the database to resume later.
+func journalProgress(
+ db kvdb.KVWriter,
+ marker []byte,
+ stats *generatorStats,
+ logger kvdb.Logger,
+) {
+ // Write out the generator marker. Note it's a standalone disk layer generator
+ // which is not mixed with journal. It's ok if the generator is persisted while
+ // journal is not.
+ entry := journalGenerator{
+ Done: marker == nil,
+ Marker: marker,
+ }
+
+ if stats != nil {
+ entry.Accounts = stats.accounts
+ entry.Slots = stats.slots
+ entry.Storage = uint64(stats.storage)
+ }
+
+ blob, err := rlp.EncodeToBytes(entry)
+ if err != nil {
+ panic(err) // Cannot happen, here to catch dev errors
+ }
+
+ var logstr string
+
+ switch {
+ case marker == nil:
+ logstr = "done"
+ case bytes.Equal(marker, []byte{}):
+ logstr = "empty"
+ case len(marker) == types.HashLength:
+ logstr = fmt.Sprintf("%#x", marker)
+ default:
+ logstr = fmt.Sprintf("%#x:%#x", marker[:types.HashLength], marker[types.HashLength:])
+ }
+
+ logger.Debug("Journalled generator progress", "progress", logstr)
+
+ rawdb.WriteSnapshotGenerator(db, blob)
+}
+
+// proofResult contains the output of range proving which can be used
+// for further processing regardless if it is successful or not.
+type proofResult struct {
+	keys     [][]byte   // The key set of all elements being iterated, even if proving failed
+	vals     [][]byte   // The val set of all elements being iterated, even if proving failed
+	diskMore bool       // Set when the database has extra snapshot states since last iteration
+	trieMore bool       // Set when the trie has extra snapshot states (only meaningful for successful proving)
+	proofErr error      // Indicator whether the given state range is valid or not
+	tr       *trie.Trie // The trie, in case the trie was resolved by the prover (may be nil)
+}
+
+// valid reports whether the range proof succeeded, i.e. no proof error was
+// recorded during proving.
+func (result *proofResult) valid() bool {
+	return result.proofErr == nil
+}
+
+// last returns the last verified element key regardless of whether the range proof is
+// successful or not. Nil is returned if nothing involved in the proving.
+func (result *proofResult) last() []byte {
+ var last []byte
+
+ if len(result.keys) > 0 {
+ last = result.keys[len(result.keys)-1]
+ }
+
+ return last
+}
+
+// forEach invokes the given callback on every visited (key, value) pair in
+// iteration order. It stops early and returns the first non-nil error the
+// callback produces; otherwise it returns nil.
+func (result *proofResult) forEach(callback func(key []byte, val []byte) error) error {
+	for idx, key := range result.keys {
+		if err := callback(key, result.vals[idx]); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// generateAccounts generates the missing snapshot accounts as well as their
+// storage slots in the main trie. It's supposed to restart the generation
+// from the given origin position.
+func generateAccounts(ctx *generatorContext, dl *diskLayer, accMarker []byte, snapmetrics *Metrics) error {
+	// onAccount is invoked by generateRange for every account in the verified
+	// range: write/needDelete describe whether the snapshot entry has to be
+	// (re)written or removed, val is the canonical (full) RLP encoding.
+	onAccount := func(key []byte, val []byte, write bool, needDelete bool) error {
+		// Make sure to clear all dangling storages before this account
+		account := types.BytesToHash(key)
+		ctx.removeStorageBefore(account)
+
+		// starting timestamp
+		start := time.Now()
+
+		// delete account and return
+		if needDelete {
+			rawdb.DeleteAccountSnapshot(ctx.batch, account)
+			// might take longer time than we suppose
+			ctx.removeStorageAt(account)
+			metrics.CounterInc(snapmetrics.wipedAccountCount)
+			metrics.HistogramObserve(snapmetrics.accountWriteNanoseconds, float64(time.Since(start).Nanoseconds()))
+
+			return nil
+		}
+		// Retrieve the current account and flatten it into the internal format
+		var acc struct {
+			Nonce    uint64
+			Balance  *big.Int
+			Root     types.Hash
+			CodeHash []byte
+		}
+
+		if err := rlp.DecodeBytes(val, &acc); err != nil {
+			// An undecodable account means the trie data itself is corrupt;
+			// there is nothing sensible to resume from.
+			dl.logger.Error("Invalid account encountered during snapshot creation", "err", err)
+			os.Exit(1)
+		}
+		// If the account is not yet in-progress, write it out
+		if accMarker == nil || !bytes.Equal(account[:], accMarker) {
+			dataLen := len(val) // Approximate size, saves us a round of RLP-encoding
+
+			if !write {
+				// Entry already correct on disk: only account for its slim
+				// size (empty code hash / empty root are omitted in slim form).
+				if bytes.Equal(acc.CodeHash, types.EmptyCodeHash.Bytes()) {
+					dataLen -= 32
+				}
+
+				if acc.Root == types.EmptyRootHash {
+					dataLen -= 32
+				}
+
+				metrics.CounterInc(snapmetrics.recoveredAccountCount)
+			} else {
+				data := SlimAccountRLP(acc.Nonce, acc.Balance, acc.Root, acc.CodeHash)
+				dataLen = len(data)
+				rawdb.WriteAccountSnapshot(ctx.batch, account, data)
+				metrics.CounterInc(snapmetrics.generatedAccountCount)
+			}
+
+			ctx.stats.storage += types.StorageSize(rawdb.SnapshotPrefixLength + types.HashLength + dataLen)
+			ctx.stats.accounts++
+		}
+
+		// If the snap generation goes here after interrupted, genMarker may go backward
+		// when last genMarker is consisted of accountHash and storageHash
+		marker := account[:]
+		if accMarker != nil && bytes.Equal(marker, accMarker) && len(dl.genMarker) > types.HashLength {
+			marker = dl.genMarker[:]
+		}
+
+		// If we've exceeded our batch allowance or termination was requested, flush to disk
+		if err := dl.checkAndFlush(ctx, marker); err != nil {
+			return err
+		}
+
+		metrics.HistogramObserve(snapmetrics.accountWriteNanoseconds, float64(time.Since(start).Nanoseconds()))
+
+		// If the iterated account is the contract, create a further loop to
+		// verify or regenerate the contract storage.
+		if acc.Root == types.EmptyRootHash {
+			ctx.removeStorageAt(account)
+		} else {
+			// Resume the storage iteration mid-account only for the exact
+			// account the previous run was interrupted on.
+			var storeMarker []byte
+
+			if accMarker != nil && bytes.Equal(account[:], accMarker) && len(dl.genMarker) > types.HashLength {
+				storeMarker = dl.genMarker[types.HashLength:]
+			}
+
+			if err := generateStorages(ctx, dl, dl.root, account, acc.Root, storeMarker); err != nil {
+				return err
+			}
+		}
+		// Some account processed, unmark the marker
+		accMarker = nil
+
+		return nil
+	}
+
+	// Always reset the initial account range as 1 whenever recover from the
+	// interruption. TODO(rjl493456442) can we remove it?
+	var accountRange = accountCheckRange
+	if len(accMarker) > 0 {
+		accountRange = 1
+	}
+
+	origin := types.CopyBytes(accMarker)
+
+	for {
+		id := trie.StateTrieID(dl.root)
+
+		exhausted, last, err := dl.generateRange(ctx, id, rawdb.SnapshotAccountPrefix, snapAccount,
+			origin, accountRange, onAccount, FullAccountRLP)
+		if err != nil {
+			return err // The procedure it aborted, either by external signal or internal error.
+		}
+
+		origin = increaseKey(last)
+
+		// Last step, cleanup the storages after the last account.
+		// All the left storages should be treated as dangling.
+		if origin == nil || exhausted {
+			ctx.removeStorageLeft()
+
+			break
+		}
+
+		accountRange = accountCheckRange
+	}
+
+	return nil
+}
+
+// generate is a background thread that iterates over the state and storage tries,
+// constructing the state snapshot. All the arguments are purely for statistics
+// gathering and logging, since the method surfs the blocks as they arrive, often
+// being restarted.
+func (dl *diskLayer) generate(stats *generatorStats) {
+	var (
+		accMarker []byte
+		abort     chan *generatorStats
+	)
+
+	if len(dl.genMarker) > 0 { // []byte{} is the start, use nil for that
+		accMarker = dl.genMarker[:types.HashLength]
+	}
+
+	stats.Log("Resuming state snapshot generation", dl.root, dl.genMarker)
+
+	// set status
+	metrics.SetGauge(stats.generateMetrics.estimateSeconds, 0)
+	metrics.SetGauge(stats.generateMetrics.usedSeconds, 0)
+
+	// Initialize the global generator context. The snapshot iterators are
+	// opened at the interrupted position because the assumption is held
+	// that all the snapshot data are generated correctly before the marker.
+	// Even if the snapshot data is updated during the interruption (before
+	// or at the marker), the assumption is still held.
+	// For the account or storage slot at the interruption, they will be
+	// processed twice by the generator(they are already processed in the
+	// last run) but it's fine.
+	ctx := newGeneratorContext(dl.snapmetrics, stats, dl.diskdb, accMarker, dl.genMarker)
+	defer ctx.close()
+
+	if err := generateAccounts(ctx, dl, accMarker, dl.snapmetrics); err != nil {
+		// Extract the received interruption signal if exists. No need to
+		// pre-allocate the target: errors.As populates it on a match.
+		var aerr *abortError
+		if errors.As(err, &aerr) {
+			abort = aerr.abort
+		}
+		// Aborted by internal error, wait the signal
+		if abort == nil {
+			abort = <-dl.genAbort
+		}
+
+		abort <- stats
+
+		return
+	}
+	// Snapshot fully generated, set the marker to nil.
+	// Note even there is nothing to commit, persist the
+	// generator anyway to mark the snapshot is complete.
+	journalProgress(ctx.batch, nil, stats, dl.logger)
+
+	if err := ctx.batch.Write(); err != nil {
+		dl.logger.Error("Failed to flush batch", "err", err)
+
+		abort = <-dl.genAbort
+		abort <- stats
+
+		return
+	}
+
+	ctx.batch.Reset()
+
+	dl.logger.Info("Generated state snapshot",
+		"accounts", stats.accounts,
+		"slots", stats.slots,
+		"storage", stats.storage,
+		"dangling", stats.dangling,
+		"elapsed", types.PrettyDuration(time.Since(stats.start)),
+	)
+
+	// final metrics
+	metrics.SetGauge(stats.generateMetrics.estimateSeconds, 0)
+	metrics.SetGauge(stats.generateMetrics.usedSeconds, time.Since(stats.start).Seconds())
+
+	// Publish completion under the lock so readers never observe a nil marker
+	// without the pending channel being closed.
+	dl.lock.Lock()
+	dl.genMarker = nil
+	close(dl.genPending)
+	dl.lock.Unlock()
+
+	// Someone will be looking for us, wait it out
+	abort = <-dl.genAbort
+	abort <- nil
+}
+
+// proveRange proves the snapshot segment with particular prefix is "valid".
+// The iteration start point will be assigned if the iterator is restored from
+// the last interruption. Max will be assigned in order to limit the maximum
+// amount of data involved in each iteration.
+//
+// The proof result will be returned if the range proving is finished, otherwise
+// the error will be returned to abort the entire procedure.
+func (dl *diskLayer) proveRange(
+	ctx *generatorContext,
+	trieID *trie.ID,
+	prefix []byte,
+	kind string,
+	origin []byte,
+	max int,
+	valueConvertFn func([]byte) ([]byte, error),
+) (*proofResult, error) {
+	var (
+		keys     [][]byte
+		vals     [][]byte
+		proof    = rawdb.NewMemoryDatabase()
+		diskMore = false
+		iter     = ctx.iterator(kind)
+		// NOTE(review): append may write into prefix's backing array if it has
+		// spare capacity — confirm the snapshot prefix constants are
+		// full-capacity slices so min never aliases them.
+		min   = append(prefix, origin...)
+		start = time.Now()
+	)
+
+	for iter.Next() {
+		// Ensure the iterated item is always equal or larger than the given origin.
+		key := iter.Key()
+		if bytes.Compare(key, min) < 0 {
+			return nil, errors.New("invalid iteration position")
+		}
+		// Ensure the iterated item still fall in the specified prefix. If
+		// not which means the items in the specified area are all visited.
+		// Move the iterator a step back since we iterate one extra element
+		// out.
+		if !bytes.Equal(key[:len(prefix)], prefix) {
+			iter.Hold()
+
+			break
+		}
+		// Break if we've reached the max size, and signal that we're not
+		// done yet. Move the iterator a step back since we iterate one
+		// extra element out.
+		if len(keys) == max {
+			iter.Hold()
+
+			diskMore = true
+
+			break
+		}
+
+		keys = append(keys, types.CopyBytes(key[len(prefix):]))
+
+		if valueConvertFn == nil {
+			vals = append(vals, types.CopyBytes(iter.Value()))
+		} else {
+			val, err := valueConvertFn(iter.Value())
+			if err != nil {
+				// Special case, the state data is corrupted (invalid slim-format account),
+				// don't abort the entire procedure directly. Instead, let the fallback
+				// generation to heal the invalid data.
+				//
+				// Here append the original value to ensure that the number of key and
+				// value are aligned.
+				vals = append(vals, types.CopyBytes(iter.Value()))
+				dl.logger.Error("Failed to convert account state data",
+					"err", err,
+					"kind", kind,
+					"prefix", hex.EncodeToHex(prefix),
+				)
+			} else {
+				vals = append(vals, val)
+			}
+		}
+	}
+
+	// Update metrics for database iteration and merkle proving
+	if kind == snapStorage {
+		metrics.HistogramObserve(dl.snapmetrics.storageSnapReadNanoseconds,
+			float64(time.Since(start).Nanoseconds()))
+	} else {
+		metrics.HistogramObserve(dl.snapmetrics.accountSnapReadNanoseconds,
+			float64(time.Since(start).Nanoseconds()))
+	}
+
+	// Record proving duration on every exit path below.
+	defer func(start time.Time) {
+		if kind == snapStorage {
+			metrics.HistogramObserve(dl.snapmetrics.storageProveNanoseconds,
+				float64(time.Since(start).Nanoseconds()))
+		} else {
+			metrics.HistogramObserve(dl.snapmetrics.accountProveNanoseconds,
+				float64(time.Since(start).Nanoseconds()))
+		}
+	}(time.Now())
+
+	// The snap state is exhausted, pass the entire key/val set for verification
+	root := trieID.Root
+
+	if origin == nil && !diskMore {
+		// Full-range case: rebuild the root hash from the flat data with a
+		// stack trie and compare it against the expected root directly.
+		stackTr := trie.NewStackTrie(nil)
+		for i, key := range keys {
+			stackTr.TryUpdate(key, vals[i])
+		}
+
+		if gotRoot := stackTr.Hash(); gotRoot != root {
+			return &proofResult{
+				keys:     keys,
+				vals:     vals,
+				proofErr: fmt.Errorf("wrong root: have %s want %s", gotRoot, root),
+			}, nil
+		}
+
+		return &proofResult{keys: keys, vals: vals}, nil
+	}
+	// Snap state is chunked, generate edge proofs for verification.
+	tr, err := trie.New(trieID, dl.triedb, dl.logger)
+	if err != nil {
+		ctx.stats.Log("Trie missing, state snapshotting paused", dl.root, dl.genMarker)
+
+		return nil, errMissingTrie
+	}
+	// Firstly find out the key of last iterated element.
+	var last []byte
+	if len(keys) > 0 {
+		last = keys[len(keys)-1]
+	}
+	// Generate the Merkle proofs for the first and last element
+	if origin == nil {
+		origin = types.Hash{}.Bytes()
+	}
+
+	if err := tr.Prove(origin, 0, proof); err != nil {
+		dl.logger.Debug("Failed to prove range", "kind", kind, "origin", origin, "err", err)
+
+		return &proofResult{
+			keys:     keys,
+			vals:     vals,
+			diskMore: diskMore,
+			proofErr: err,
+			tr:       tr,
+		}, nil
+	}
+
+	if last != nil {
+		if err := tr.Prove(last, 0, proof); err != nil {
+			dl.logger.Debug("Failed to prove range", "kind", kind, "last", last, "err", err)
+
+			return &proofResult{
+				keys:     keys,
+				vals:     vals,
+				diskMore: diskMore,
+				proofErr: err,
+				tr:       tr,
+			}, nil
+		}
+	}
+	// Verify the snapshot segment with range prover, ensure that all flat states
+	// in this range correspond to merkle trie.
+	cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof)
+
+	return &proofResult{
+		keys:     keys,
+		vals:     vals,
+		diskMore: diskMore,
+		trieMore: cont,
+		proofErr: err,
+		tr:       tr},
+		nil
+}
+
+// onStateCallback is a function that is called by generateRange, when processing a range of
+// accounts or storage slots. For each element, the callback is invoked.
+//
+// - If 'needDelete' is true, then this element (and potential slots) needs to be deleted from the snapshot.
+// - If 'write' is true, then this element needs to be updated with the 'val'.
+// - If 'write' is false, then this element is already correct, and needs no update.
+// The 'val' is the canonical encoding of the value (not the slim format for accounts).
+//
+// However, for accounts, the storage trie of the account needs to be checked. Also,
+// dangling storages (storage exists but the corresponding account is missing) need to
+// be cleaned up.
+type onStateCallback func(key []byte, val []byte, write bool, needDelete bool) error
+
+// generateRange generates the state segment with particular prefix. Generation can
+// either verify the correctness of existing state through range-proof and skip
+// generation, or iterate trie to regenerate state on demand.
+//
+// Parameters:
+//   - ctx: shared generator context (write batch, iterators, statistics)
+//   - trieID: identifier of the trie backing this state segment
+//   - prefix: database prefix of the flat states being processed
+//   - kind: snapAccount or snapStorage, used for metrics and logging only
+//   - origin: position to start from within the segment (nil = beginning)
+//   - max: maximum number of states handled in this pass
+//   - onState: callback invoked for every state element encountered
+//   - valueConvertFn: optional value conversion applied before range proving
+//
+// It returns whether the whole segment is exhausted, the last processed key
+// (the position to resume from) and any error that aborted generation.
+func (dl *diskLayer) generateRange(
+	ctx *generatorContext,
+	trieID *trie.ID,
+	prefix []byte,
+	kind string,
+	origin []byte,
+	max int,
+	onState onStateCallback,
+	valueConvertFn func([]byte) ([]byte, error),
+) (bool, []byte, error) {
+	// Use range prover to check the validity of the flat state in the range
+	result, err := dl.proveRange(ctx, trieID, prefix, kind, origin, max, valueConvertFn)
+	if err != nil {
+		return false, nil, err
+	}
+
+	last := result.last()
+
+	// The range prover says the range is correct, skip trie iteration
+	if result.valid() {
+		metrics.CounterInc(dl.snapmetrics.successfulRangeProofCount)
+
+		// The verification is passed, process each state with the given
+		// callback function. If this state represents a contract, the
+		// corresponding storage check will be performed in the callback
+		if err := result.forEach(func(key []byte, val []byte) error { return onState(key, val, false, false) }); err != nil {
+			return false, nil, err
+		}
+
+		// Only abort the iteration when both database and trie are exhausted
+		return !result.diskMore && !result.trieMore, last, nil
+	}
+
+	// The range proof failed: the flat state in this segment is outdated
+	// and must be regenerated by iterating the trie below.
+	dl.logger.Debug("Detected outdated state range",
+		"kind", kind,
+		"prefix", hex.EncodeToHex(prefix),
+		"last", hex.EncodeToHex(last),
+		"err", result.proofErr,
+	)
+	metrics.CounterInc(dl.snapmetrics.failedRangeProofCount)
+
+	// Special case, the entire trie is missing. In the original trie scheme,
+	// all the duplicated subtries will be filtered out (only one copy of data
+	// will be stored). While in the snapshot model, all the storage tries
+	// belong to different contracts will be kept even they are duplicated.
+	// Track it to some extent to remove the noise data used for statistics.
+	if origin == nil && last == nil {
+		if kind == snapStorage {
+			metrics.CounterInc(dl.snapmetrics.missallStorageCount)
+		} else {
+			metrics.CounterInc(dl.snapmetrics.missallAccountCount)
+		}
+	}
+
+	// We use the snap data to build up a cache which can be used by the
+	// main account trie as a primary lookup when resolving hashes
+	var snapNodeCache kvdb.Database
+
+	if len(result.keys) > 0 {
+		snapNodeCache = rawdb.NewMemoryDatabase()
+		snapTrieDB := trie.NewDatabase(snapNodeCache, dl.logger)
+		snapTrie := trie.NewEmpty(snapTrieDB)
+
+		for i, key := range result.keys {
+			snapTrie.Update(key, result.vals[i])
+		}
+
+		// Commit errors are deliberately ignored: the trie is backed by an
+		// in-memory database and only serves as a best-effort lookup cache.
+		root, nodes, _ := snapTrie.Commit(false)
+		if nodes != nil {
+			snapTrieDB.Update(trie.NewWithNodeSet(nodes))
+		}
+
+		snapTrieDB.Commit(root, false, nil)
+	}
+
+	// Construct the trie for state iteration, reuse the trie
+	// if it's already opened with some nodes resolved.
+	tr := result.tr
+	if tr == nil {
+		tr, err = trie.New(trieID, dl.triedb, dl.logger)
+		if err != nil {
+			ctx.stats.Log("Trie missing, state snapshotting paused", dl.root, dl.genMarker)
+
+			return false, nil, errMissingTrie
+		}
+	}
+
+	var (
+		trieMore       bool
+		nodeIt         = tr.NodeIterator(origin)
+		iter           = trie.NewIterator(nodeIt)
+		kvkeys, kvvals = result.keys, result.vals
+
+		// counters
+		count     = 0 // number of states delivered by iterator
+		created   = 0 // states created from the trie
+		updated   = 0 // states updated from the trie
+		deleted   = 0 // states not in trie, but were in snapshot
+		untouched = 0 // states already correct
+
+		// timers
+		start    = time.Now()
+		internal time.Duration
+	)
+
+	nodeIt.AddResolver(snapNodeCache)
+
+	// Walk the trie and reconcile it against the (sorted) flat-state keys.
+	for iter.Next() {
+		if last != nil && bytes.Compare(iter.Key, last) > 0 {
+			trieMore = true
+
+			break
+		}
+
+		count++
+		created++
+
+		write := true
+
+		for len(kvkeys) > 0 {
+			if cmp := bytes.Compare(kvkeys[0], iter.Key); cmp < 0 {
+				istart := time.Now()
+				// delete the key
+				if err := onState(kvkeys[0], nil, false, true); err != nil {
+					return false, nil, err
+				}
+
+				kvkeys = kvkeys[1:]
+				kvvals = kvvals[1:]
+				deleted++
+				// calculate internal duration
+				internal += time.Since(istart)
+
+				continue
+			} else if cmp == 0 {
+				// the snapshot key can be overwritten
+				created--
+
+				if write = !bytes.Equal(kvvals[0], iter.Value); write {
+					updated++
+				} else {
+					untouched++
+				}
+
+				kvkeys = kvkeys[1:]
+				kvvals = kvvals[1:]
+			}
+
+			break
+		}
+
+		istart := time.Now()
+		// onstate callback
+		if err := onState(iter.Key, iter.Value, write, false); err != nil {
+			return false, nil, err
+		}
+		// calculate internal duration
+		internal += time.Since(istart)
+	}
+
+	if iter.Err != nil {
+		return false, nil, iter.Err
+	}
+
+	istart := time.Now()
+	// Delete all stale snapshot states remaining
+	for _, key := range kvkeys {
+		if err := onState(key, nil, false, true); err != nil {
+			return false, nil, err
+		}
+
+		deleted++
+	}
+	// calculate internal duration
+	internal += time.Since(istart)
+
+	// Record pure trie read time (total minus callback time).
+	if kind == snapStorage {
+		metrics.HistogramObserve(dl.snapmetrics.storageTrieReadNanoseconds,
+			float64((time.Since(start) - internal).Nanoseconds()))
+	} else {
+		metrics.HistogramObserve(dl.snapmetrics.accountTrieReadNanoseconds,
+			float64((time.Since(start) - internal).Nanoseconds()))
+	}
+
+	dl.logger.Debug("Regenerated state range",
+		"kind", kind,
+		"prefix", hex.EncodeToHex(prefix),
+		"origin", hex.EncodeToHex(origin),
+		"root", trieID.Root,
+		"last", hex.EncodeToHex(last),
+		"count", count,
+		"created", created,
+		"updated", updated,
+		"untouched", untouched,
+		"deleted", deleted,
+	)
+
+	// If there are either more trie items, or there are more snap items
+	// (in the next segment), then we need to keep working
+	return !trieMore && !result.diskMore, last, nil
+}
+
+// checkAndFlush checks if an interruption signal is received or the
+// batch size has exceeded the allowance. In either case the progress
+// marker 'current' is journaled and the write batch flushed to disk;
+// on abort an abortError carrying the aborter's stats channel is
+// returned so the generator loop can unwind.
+func (dl *diskLayer) checkAndFlush(ctx *generatorContext, current []byte) error {
+	var abort chan *generatorStats
+
+	// Non-blocking poll for an external abort request.
+	select {
+	case abort = <-dl.genAbort:
+	default:
+	}
+
+	if ctx.batch.ValueSize() > kvdb.IdealBatchSize || abort != nil {
+		// The marker must only ever advance; going backwards indicates a bug.
+		if bytes.Compare(current, dl.genMarker) < 0 {
+			dl.logger.Error("Snapshot generator went backwards",
+				"current", fmt.Sprintf("%x", current),
+				"genMarker", fmt.Sprintf("%x", dl.genMarker),
+			)
+		}
+
+		// Flush out the batch anyway no matter it's empty or not.
+		// It's possible that all the states are recovered and the
+		// generation indeed makes progress.
+		journalProgress(ctx.batch, current, ctx.stats, dl.logger)
+
+		if err := ctx.batch.Write(); err != nil {
+			return err
+		}
+
+		ctx.batch.Reset()
+
+		// Publish the new marker under the layer lock so concurrent readers
+		// see a consistent generated/ungenerated boundary.
+		dl.lock.Lock()
+		dl.genMarker = current
+		dl.lock.Unlock()
+
+		if abort != nil {
+			ctx.stats.Log("Aborting state snapshot generation", dl.root, current)
+
+			return newAbortError(abort) // bubble up an error for interruption
+		}
+		// Don't hold the iterators too long, release them to let compactor works
+		ctx.reopenIterator(snapAccount)
+		ctx.reopenIterator(snapStorage)
+	}
+
+	// Periodic progress report, at most once every 8 seconds.
+	if time.Since(ctx.logged) > 8*time.Second {
+		ctx.stats.Log("Generating state snapshot", dl.root, current)
+		ctx.logged = time.Now()
+	}
+
+	return nil
+}
+
+// generateStorages generates the missing storage slots of the specific contract.
+// It's supposed to restart the generation from the given origin position.
+//
+// stateRoot identifies the state the storage belongs to, account is the hashed
+// account key, storageRoot the expected storage trie root and storeMarker the
+// interrupted position to resume from (nil for a fresh start).
+func generateStorages(
+	ctx *generatorContext,
+	dl *diskLayer,
+	stateRoot types.Hash,
+	account types.Hash,
+	storageRoot types.Hash,
+	storeMarker []byte,
+) error {
+	// onStorage persists/deletes a single slot and updates metrics and stats.
+	onStorage := func(key []byte, val []byte, write bool, needDelete bool) error {
+		defer func(start time.Time) {
+			metrics.HistogramObserve(dl.snapmetrics.storageWriteNanoseconds,
+				float64(time.Since(start).Nanoseconds()))
+		}(time.Now())
+
+		if needDelete {
+			rawdb.DeleteStorageSnapshot(ctx.batch, account, types.BytesToHash(key))
+			metrics.CounterInc(dl.snapmetrics.wipedStorageCount)
+
+			return nil
+		}
+
+		if write {
+			rawdb.WriteStorageSnapshot(ctx.batch, account, types.BytesToHash(key), val)
+			metrics.CounterInc(dl.snapmetrics.generatedStorageCount)
+		} else {
+			// Flat state already correct; nothing written.
+			metrics.CounterInc(dl.snapmetrics.recoveredStorageCount)
+		}
+
+		ctx.stats.storage += types.StorageSize(rawdb.SnapshotPrefixLength + 2*types.HashLength + len(val))
+		ctx.stats.slots++
+
+		// If we've exceeded our batch allowance or termination was requested, flush to disk
+		if err := dl.checkAndFlush(ctx, append(account[:], key...)); err != nil {
+			return err
+		}
+
+		return nil
+	}
+
+	// Loop for re-generating the missing storage slots.
+	var origin = types.CopyBytes(storeMarker)
+
+	for {
+		id := trie.StorageTrieID(stateRoot, account, storageRoot)
+
+		exhausted, last, err := dl.generateRange(
+			ctx,
+			id,
+			append(rawdb.SnapshotStoragePrefix, account.Bytes()...),
+			snapStorage,
+			origin,
+			storageCheckRange,
+			onStorage,
+			nil,
+		)
+		if err != nil {
+			return err // The procedure it aborted, either by external signal or internal error.
+		}
+
+		// Abort the procedure if the entire contract storage is generated
+		if exhausted {
+			break
+		}
+
+		// Advance the origin past the last processed key for the next pass.
+		if origin = increaseKey(last); origin == nil {
+			break // special case, the last is 0xffffffff...fff
+		}
+	}
+
+	return nil
+}
+
+// increaseKey increments the input key by one. Return nil if the entire
+// addition operation overflows (the key was all 0xff bytes).
+//
+// NOTE: the key is mutated in place; callers must not rely on the previous
+// contents of the slice afterwards.
+func increaseKey(key []byte) []byte {
+	for i := len(key) - 1; i >= 0; i-- {
+		key[i]++
+
+		// No carry produced, the increment is complete.
+		if key[i] != 0x0 {
+			return key
+		}
+	}
+
+	return nil
+}
+
+// abortError wraps an interruption signal received to represent the
+// generation is aborted by external processes.
+type abortError struct {
+	// abort is the channel the aborter is waiting on for the final stats.
+	abort chan *generatorStats
+}
+
+// newAbortError wraps the given abort channel into an error value so the
+// interruption can bubble up through the generation call stack.
+func newAbortError(abort chan *generatorStats) error {
+	return &abortError{abort: abort}
+}
+
+// Error implements the error interface.
+func (err *abortError) Error() string {
+	return "aborted"
+}
diff --git a/state/snapshot/generate_test.go b/state/snapshot/generate_test.go
new file mode 100644
index 0000000000..1147eaa8e6
--- /dev/null
+++ b/state/snapshot/generate_test.go
@@ -0,0 +1,872 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "fmt"
+ "math/big"
+ "testing"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+ "golang.org/x/crypto/sha3"
+)
+
+// hashData returns the Keccak-256 hash of the input, matching how accounts
+// and storage slots are keyed in the snapshot.
+func hashData(input []byte) types.Hash {
+	var hasher = sha3.NewLegacyKeccak256()
+	var hash types.Hash
+	hasher.Reset()
+	hasher.Write(input)
+	hasher.Sum(hash[:0])
+	return hash
+}
+
+// Tests that snapshot generation from an empty database.
+func TestGeneration(t *testing.T) {
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// two of which also has the same 3-slot storage trie attached.
+	var helper = newHelper()
+	stRoot := helper.makeStorageTrie(types.Hash{}, types.Hash{}, []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, false)
+
+	helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+
+	root, snap := helper.CommitAndGenerate()
+	// The root is fixed by the deterministic content above; a mismatch means
+	// the trie commit logic (not the snapshotter) changed.
+	if have, want := root, types.StringToHash("0xe3712f1a226f3782caca78ca770ccc19ee000552813a9f59d479f8611db9b1fd"); have != want {
+		t.Fatalf("have %s want %s", have, want)
+	}
+	// Wait for the background generator to finish (or time out).
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+
+	case <-time.After(3 * time.Second):
+		t.Errorf("Snapshot generation failed")
+	}
+	checkSnapRoot(t, snap, root)
+
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// Tests that snapshot generation with existent flat state.
+func TestGenerateExistentState(t *testing.T) {
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// two of which also has the same 3-slot storage trie attached.
+	var helper = newHelper()
+
+	// acc-1: contract account, correct in both trie and flat snapshot.
+	stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapStorage("acc-1", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+	// acc-2: plain (storage-less) account, correct in both.
+	helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+
+	// acc-3: contract account, correct in both.
+	stRoot = helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapStorage("acc-3", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+	root, snap := helper.CommitAndGenerate()
+	// Wait for the background generator to finish (or time out).
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+
+	case <-time.After(3 * time.Second):
+		t.Errorf("Snapshot generation failed")
+	}
+	checkSnapRoot(t, snap, root)
+
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// checkSnapRoot cross-validates a generated snapshot: it re-derives the state
+// root purely from the flat snapshot data (accounts plus per-account storage)
+// and asserts it equals the trie root, then scans for dangling storage.
+func checkSnapRoot(t *testing.T, snap *diskLayer, trieRoot types.Hash) {
+	t.Helper()
+
+	logger := snap.logger
+
+	accIt := snap.AccountIterator(types.Hash{})
+	defer accIt.Release()
+
+	snapRoot, err := generateTrieRoot(nil, nil, accIt, types.Hash{}, stackTrieGenerate,
+		func(db kvdb.KVWriter, accountHash, codeHash types.Hash, stat *generateStats) (types.Hash, error) {
+			// Rebuild each contract's storage root from its snapshot slots.
+			storageIt, _ := snap.StorageIterator(accountHash, types.Hash{})
+			defer storageIt.Release()
+
+			hash, err := generateTrieRoot(nil, nil, storageIt, accountHash, stackTrieGenerate, nil, stat, false)
+			if err != nil {
+				return types.Hash{}, err
+			}
+			return hash, nil
+		}, newGenerateStats(snap.logger, snap.snapmetrics), true)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if snapRoot != trieRoot {
+		t.Fatalf("snaproot: %s != trieroot %s", snapRoot, trieRoot)
+	}
+	if err := CheckDanglingStorage(snap.diskdb, logger); err != nil {
+		t.Fatalf("Detected dangling storages: %v", err)
+	}
+}
+
+// testHelper wires together the disk database, trie database and account
+// trie needed to hand-craft a state for snapshot generation tests without
+// depending on the statedb package.
+type testHelper struct {
+	diskdb      kvdb.Database
+	triedb      *trie.Database
+	accTrie     *trie.StateTrie
+	nodes       *trie.MergedNodeSet
+	logger      kvdb.Logger
+	snapmetrics *Metrics
+}
+
+// newHelper builds a testHelper backed by an in-memory database, with a
+// fresh empty account trie and no-op metrics.
+func newHelper() *testHelper {
+	diskdb := rawdb.NewMemoryDatabase()
+	logger := hclog.NewNullLogger()
+	triedb := trie.NewDatabase(diskdb, logger)
+	accTrie, _ := trie.NewStateTrie(trie.StateTrieID(types.Hash{}), triedb, logger)
+
+	return &testHelper{
+		diskdb:      diskdb,
+		triedb:      triedb,
+		accTrie:     accTrie,
+		nodes:       trie.NewMergedNodeSet(),
+		logger:      logger,
+		snapmetrics: NilMetrics(),
+	}
+}
+
+// addTrieAccount inserts the RLP-encoded account into the account trie only.
+func (t *testHelper) addTrieAccount(acckey string, acc *Account) {
+	val, _ := rlp.EncodeToBytes(acc)
+	t.accTrie.Update([]byte(acckey), val)
+}
+
+// addSnapAccount writes the RLP-encoded account into the flat snapshot only,
+// keyed by the hash of the account key.
+func (t *testHelper) addSnapAccount(acckey string, acc *Account) {
+	val, _ := rlp.EncodeToBytes(acc)
+	key := hashData([]byte(acckey))
+	rawdb.WriteAccountSnapshot(t.diskdb, key, val)
+}
+
+// addAccount inserts the account into both the trie and the flat snapshot.
+func (t *testHelper) addAccount(acckey string, acc *Account) {
+	t.addTrieAccount(acckey, acc)
+	t.addSnapAccount(acckey, acc)
+}
+
+// addSnapStorage writes the given slots into the flat snapshot for accKey.
+func (t *testHelper) addSnapStorage(accKey string, keys []string, vals []string) {
+	accHash := hashData([]byte(accKey))
+	for i, key := range keys {
+		rawdb.WriteStorageSnapshot(t.diskdb, accHash, hashData([]byte(key)), []byte(vals[i]))
+	}
+}
+
+// makeStorageTrie fills a storage trie with the given slots and returns its
+// root. If commit is false the trie is only hashed (nothing persisted);
+// otherwise the dirty nodes are merged into t.nodes for a later Commit.
+func (t *testHelper) makeStorageTrie(stateRoot, owner types.Hash, keys []string, vals []string, commit bool) []byte {
+	id := trie.StorageTrieID(stateRoot, owner, types.Hash{})
+	stTrie, _ := trie.NewStateTrie(id, t.triedb, t.logger)
+	for i, k := range keys {
+		stTrie.Update([]byte(k), []byte(vals[i]))
+	}
+	if !commit {
+		return stTrie.Hash().Bytes()
+	}
+	root, nodes, _ := stTrie.Commit(false)
+	if nodes != nil {
+		t.nodes.Merge(nodes)
+	}
+	return root.Bytes()
+}
+
+// Commit flushes the account trie and all accumulated storage nodes into
+// the trie database and persists them to disk, returning the state root.
+func (t *testHelper) Commit() types.Hash {
+	root, nodes, _ := t.accTrie.Commit(true)
+	if nodes != nil {
+		t.nodes.Merge(nodes)
+	}
+	t.triedb.Update(t.nodes)
+	t.triedb.Commit(root, false, nil)
+	return root
+}
+
+// CommitAndGenerate commits the hand-built state and kicks off background
+// snapshot generation against the committed root.
+func (t *testHelper) CommitAndGenerate() (types.Hash, *diskLayer) {
+	root := t.Commit()
+	snap := generateSnapshot(t.diskdb, t.triedb, 16, root, t.logger, t.snapmetrics)
+
+	return root, snap
+}
+
+// Tests that snapshot generation with existent flat state, where the flat state
+// contains some errors:
+// - the contract with empty storage root but has storage entries in the disk
+// - the contract with non empty storage root but empty storage slots
+// - the contract(non-empty storage) misses some storage slots
+//   - miss in the beginning
+//   - miss in the middle
+//   - miss in the end
+//
+// - the contract(non-empty storage) has wrong storage slots
+//   - wrong slots in the beginning
+//   - wrong slots in the middle
+//   - wrong slots in the end
+//
+// - the contract(non-empty storage) has extra storage slots
+//   - extra slots in the beginning
+//   - extra slots in the middle
+//   - extra slots in the end
+func TestGenerateExistentStateWithWrongStorage(t *testing.T) {
+	helper := newHelper()
+
+	// Account one, empty root but non-empty database
+	helper.addAccount("acc-1", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+	helper.addSnapStorage("acc-1", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+	// Account two, non empty root but empty database
+	stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-2")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.addAccount("acc-2", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+
+	// NOTE: all contracts below share the same slot content, so the stRoot
+	// computed for acc-2 is reused; makeStorageTrie is still called per
+	// account to persist each owner's trie nodes.
+
+	// Miss slots
+	{
+		// Account three, non empty root but misses slots in the beginning
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-3", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-3", []string{"key-2", "key-3"}, []string{"val-2", "val-3"})
+
+		// Account four, non empty root but misses slots in the middle
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-4")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-4", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-4", []string{"key-1", "key-3"}, []string{"val-1", "val-3"})
+
+		// Account five, non empty root but misses slots in the end
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-5")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-5", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-5", []string{"key-1", "key-2"}, []string{"val-1", "val-2"})
+	}
+
+	// Wrong storage slots
+	{
+		// Account six, non empty root but wrong slots in the beginning
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-6")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-6", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-6", []string{"key-1", "key-2", "key-3"}, []string{"badval-1", "val-2", "val-3"})
+
+		// Account seven, non empty root but wrong slots in the middle
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-7")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-7", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-7", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "badval-2", "val-3"})
+
+		// Account eight, non empty root but wrong slots in the end
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-8")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-8", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-8", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "badval-3"})
+
+		// Account 9, non empty root but rotated slots
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-9")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-9", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-9", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-3", "val-2"})
+	}
+
+	// Extra storage slots
+	{
+		// Account 10, non empty root but extra slots in the beginning
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-10")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-10", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-10", []string{"key-0", "key-1", "key-2", "key-3"}, []string{"val-0", "val-1", "val-2", "val-3"})
+
+		// Account 11, non empty root but extra slots in the middle
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-11")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-11", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-11", []string{"key-1", "key-2", "key-2-1", "key-3"}, []string{"val-1", "val-2", "val-2-1", "val-3"})
+
+		// Account 12, non empty root but extra slots in the end
+		helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-12")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+		helper.addAccount("acc-12", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapStorage("acc-12", []string{"key-1", "key-2", "key-3", "key-4"}, []string{"val-1", "val-2", "val-3", "val-4"})
+	}
+
+	root, snap := helper.CommitAndGenerate()
+	t.Logf("Root: %s\n", root) // Root = 0x8746cce9fd9c658b2cfd639878ed6584b7a2b3e73bb40f607fcfa156002429a0
+
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+
+	case <-time.After(3 * time.Second):
+		t.Errorf("Snapshot generation failed")
+	}
+	checkSnapRoot(t, snap, root)
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// Tests that snapshot generation with existent flat state, where the flat state
+// contains some errors:
+// - miss accounts
+// - wrong accounts
+// - extra accounts
+func TestGenerateExistentStateWithWrongAccounts(t *testing.T) {
+	helper := newHelper()
+
+	// All contracts carry identical slots, so every call yields the same
+	// storage root; the tries are persisted per owner regardless.
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-2")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-4")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-6")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+
+	// Trie accounts [acc-1, acc-2, acc-3, acc-4, acc-6]
+	// Extra accounts [acc-0, acc-5, acc-7]
+
+	// Missing accounts, only in the trie
+	{
+		helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}) // Beginning
+		helper.addTrieAccount("acc-4", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}) // Middle
+		helper.addTrieAccount("acc-6", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}) // End
+	}
+
+	// Wrong accounts
+	{
+		helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapAccount("acc-2", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.StringToBytes("0x1234")})
+
+		helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+		helper.addSnapAccount("acc-3", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+	}
+
+	// Extra accounts, only in the snap
+	{
+		helper.addSnapAccount("acc-0", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyRootHash.Bytes()})                          // before the beginning
+		helper.addSnapAccount("acc-5", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.StringToBytes("0x1234")})  // Middle
+		helper.addSnapAccount("acc-7", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyRootHash.Bytes()})    // after the end
+	}
+
+	root, snap := helper.CommitAndGenerate()
+	t.Logf("Root: %s\n", root) // Root = 0x825891472281463511e7ebcc7f109e4f9200c20fa384754e11fd605cd98464e8
+
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+
+	case <-time.After(3 * time.Second):
+		t.Errorf("Snapshot generation failed")
+	}
+	checkSnapRoot(t, snap, root)
+
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// Tests that snapshot generation errors out correctly in case of a missing trie
+// node in the account trie.
+func TestGenerateCorruptAccountTrie(t *testing.T) {
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// without any storage slots to keep the test smaller.
+	helper := newHelper()
+
+	helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}) // 0xc7a30f39aff471c95d8a837497ad0e49b65be475cc0953540f80cfcdbdcd9074
+	helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}) // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7
+	helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}) // 0x19ead688e907b0fab07176120dceec244a72aff2f0aa51e8b827584e378772f4
+
+	root := helper.Commit() // Root: 0xa04693ea110a31037fb5ee814308a6f1d76bdab0b11676bdf4541d2de55ba978
+
+	// Delete an account trie leaf and ensure the generator chokes
+	helper.triedb.Commit(root, false, nil)
+	helper.diskdb.Delete(types.StringToHash("0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7").Bytes())
+
+	snap := generateSnapshot(helper.diskdb, helper.triedb, 16, root, helper.logger, helper.snapmetrics)
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+		t.Errorf("Snapshot generated against corrupt account trie")
+
+	case <-time.After(time.Second):
+		// Not generated fast enough, hopefully blocked inside on missing trie node fail
+	}
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// Tests that snapshot generation errors out correctly in case of a missing root
+// trie node for a storage trie. It's similar to internal corruption but it is
+// handled differently inside the generator.
+func TestGenerateMissingStorageTrie(t *testing.T) {
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// two of which also has the same 3-slot storage trie attached.
+	helper := newHelper()
+
+	stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true) // 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67
+	helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})                                     // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
+	helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})                // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7
+	stRoot = helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2
+
+	root := helper.Commit()
+
+	// Delete a storage trie root and ensure the generator chokes
+	helper.diskdb.Delete(stRoot)
+
+	snap := generateSnapshot(helper.diskdb, helper.triedb, 16, root, helper.logger, helper.snapmetrics)
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+		t.Errorf("Snapshot generated against corrupt storage trie")
+
+	case <-time.After(time.Second):
+		// Not generated fast enough, hopefully blocked inside on missing trie node fail
+	}
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+// Tests that snapshot generation errors out correctly in case of a missing trie
+// node in a storage trie.
+func TestGenerateCorruptStorageTrie(t *testing.T) {
+	// We can't use statedb to make a test trie (circular dependency), so make
+	// a fake one manually. We're going with a small account trie of 3 accounts,
+	// two of which also has the same 3-slot storage trie attached.
+	helper := newHelper()
+
+	stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true) // 0xddefcd9376dd029653ef384bd2f0a126bb755fe84fdcc9e7cf421ba454f2bc67
+	helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})                                     // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
+	helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})                // 0x65145f923027566669a1ae5ccac66f945b55ff6eaeb17d2ea8e048b7d381f2d7
+	stRoot = helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+	helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}) // 0x50815097425d000edfc8b3a4a13e175fc2bdcfee8bdfbf2d1ff61041d3c235b2
+
+	root := helper.Commit()
+
+	// Delete a storage trie leaf and ensure the generator chokes
+	helper.diskdb.Delete(types.StringToHash("0x18a0f4d79cff4459642dd7604f303886ad9d77c30cf3d7d7cedb3a693ab6d371").Bytes())
+
+	snap := generateSnapshot(helper.diskdb, helper.triedb, 16, root, helper.logger, helper.snapmetrics)
+	select {
+	case <-snap.genPending:
+		// Snapshot generation succeeded
+		t.Errorf("Snapshot generated against corrupt storage trie")
+
+	case <-time.After(time.Second):
+		// Not generated fast enough, hopefully blocked inside on missing trie node fail
+	}
+	// Signal abortion to the generator and wait for it to tear down
+	stop := make(chan *generatorStats)
+	snap.genAbort <- stop
+	<-stop
+}
+
+ // Tests snapshot generation when an extra account with storage exists in the snap state.
+func TestGenerateWithExtraAccounts(t *testing.T) {
+ helper := newHelper()
+ {
+ // Account one in the trie
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")),
+ []string{"key-1", "key-2", "key-3", "key-4", "key-5"},
+ []string{"val-1", "val-2", "val-3", "val-4", "val-5"},
+ true,
+ )
+ acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
+
+ // Identical in the snap
+ key := hashData([]byte("acc-1"))
+ rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), []byte("val-1"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), []byte("val-2"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), []byte("val-3"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-4")), []byte("val-4"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-5")), []byte("val-5"))
+ }
+ {
+ // Account two exists only in the snapshot
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-2")),
+ []string{"key-1", "key-2", "key-3", "key-4", "key-5"},
+ []string{"val-1", "val-2", "val-3", "val-4", "val-5"},
+ true,
+ )
+ acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ key := hashData([]byte("acc-2"))
+ rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-1")), []byte("b-val-1"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-2")), []byte("b-val-2"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("b-key-3")), []byte("b-val-3"))
+ }
+ root := helper.Commit()
+
+ // To verify the test: If we now inspect the snap db, there should exist extraneous storage items
+ if data := rawdb.ReadStorageSnapshot(helper.diskdb, hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data == nil {
+ t.Fatalf("expected snap storage to exist")
+ }
+ snap := generateSnapshot(helper.diskdb, helper.triedb, 16, root, helper.logger, helper.snapmetrics)
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+ // If we now inspect the snap db, there should exist no extraneous storage items
+ if data := rawdb.ReadStorageSnapshot(helper.diskdb, hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data != nil {
+ t.Fatalf("expected slot to be removed, got %v", string(data))
+ }
+}
+
+ // Tests snapshot generation when many extra accounts exist in the snap state.
+func TestGenerateWithManyExtraAccounts(t *testing.T) {
+ helper := newHelper()
+ {
+ // Account one in the trie
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")),
+ []string{"key-1", "key-2", "key-3"},
+ []string{"val-1", "val-2", "val-3"},
+ true,
+ )
+ acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
+
+ // Identical in the snap
+ key := hashData([]byte("acc-1"))
+ rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-1")), []byte("val-1"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-2")), []byte("val-2"))
+ rawdb.WriteStorageSnapshot(helper.diskdb, key, hashData([]byte("key-3")), []byte("val-3"))
+ }
+ {
+ // 1000 accounts exist only in snapshot
+ for i := 0; i < 1000; i++ {
+ //acc := &Account{Balance: big.NewInt(int64(i)), Root: stTrie.Hash().Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
+ acc := &Account{Balance: big.NewInt(int64(i)), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ key := hashData([]byte(fmt.Sprintf("acc-%d", i)))
+ rawdb.WriteAccountSnapshot(helper.diskdb, key, val)
+ }
+ }
+ root, snap := helper.CommitAndGenerate()
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
+
+// Tests this case
+// maxAccountRange 3
+// snapshot-accounts: 01, 02, 03, 04, 05, 06, 07
+// trie-accounts: 03, 07
+//
+// We iterate three snapshot storage slots (max = 3) from the database. They are 0x01, 0x02, 0x03.
+// The trie has a lot of deletions.
+// So in trie, we iterate 2 entries 0x03, 0x07. We create the 0x07 in the database and abort the procedure, because the trie is exhausted.
+// But in the database, we still have the stale storage slots 0x04, 0x05. They are not iterated yet, but the procedure is finished.
+func TestGenerateWithExtraBeforeAndAfter(t *testing.T) {
+ accountCheckRange = 3
+
+ helper := newHelper()
+ {
+ acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ helper.accTrie.Update(types.StringToHash("0x03").Bytes(), val)
+ helper.accTrie.Update(types.StringToHash("0x07").Bytes(), val)
+
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x01"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x02"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x03"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x04"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x05"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x06"), val)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x07"), val)
+ }
+ root, snap := helper.CommitAndGenerate()
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
+
+// TestGenerateWithMalformedSnapdata tests what happens if we have some junk
+// in the snapshot database, which cannot be parsed back to an account
+func TestGenerateWithMalformedSnapdata(t *testing.T) {
+ accountCheckRange = 3
+ helper := newHelper()
+ {
+ acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
+ val, _ := rlp.EncodeToBytes(acc)
+ helper.accTrie.Update(types.StringToHash("0x03").Bytes(), val)
+
+ junk := make([]byte, 100)
+ copy(junk, []byte{0xde, 0xad})
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x02"), junk)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x03"), junk)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x04"), junk)
+ rawdb.WriteAccountSnapshot(helper.diskdb, types.StringToHash("0x05"), junk)
+ }
+ root, snap := helper.CommitAndGenerate()
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+ // If we now inspect the snap db, there should exist no extraneous storage items
+ if data := rawdb.ReadStorageSnapshot(helper.diskdb, hashData([]byte("acc-2")), hashData([]byte("b-key-1"))); data != nil {
+ t.Fatalf("expected slot to be removed, got %v", string(data))
+ }
+}
+
+func TestGenerateFromEmptySnap(t *testing.T) {
+ //enableLogging()
+ accountCheckRange = 10
+ storageCheckRange = 20
+ helper := newHelper()
+ // Add 1K accounts to the trie
+ for i := 0; i < 400; i++ {
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte(fmt.Sprintf("acc-%d", i))), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+ helper.addTrieAccount(fmt.Sprintf("acc-%d", i),
+ &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+ }
+ root, snap := helper.CommitAndGenerate()
+ t.Logf("Root: %s\n", root) // Root: 0x6f7af6d2e1a1bf2b84a3beb3f8b64388465fbc1e274ca5d5d3fc787ca78f59e4
+
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
+
+// Tests that snapshot generation with existent flat state, where the flat state
+// storage is correct, but incomplete.
+// The incomplete part is on the second range
+// snap: [ 0x01, 0x02, 0x03, 0x04] , [ 0x05, 0x06, 0x07, {missing}] (with storageCheck = 4)
+// trie: 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08
+// This hits a case where the snap verification passes, but there are more elements in the trie
+// which we must also add.
+func TestGenerateWithIncompleteStorage(t *testing.T) {
+ storageCheckRange = 4
+ helper := newHelper()
+ stKeys := []string{"1", "2", "3", "4", "5", "6", "7", "8"}
+ stVals := []string{"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8"}
+ // We add 8 accounts, each one is missing exactly one of the storage slots. This means
+ // we don't have to order the keys and figure out exactly which hash-key winds up
+ // on the sensitive spots at the boundaries
+ for i := 0; i < 8; i++ {
+ accKey := fmt.Sprintf("acc-%d", i)
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte(accKey)), stKeys, stVals, true)
+ helper.addAccount(accKey, &Account{Balance: big.NewInt(int64(i)), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+ var moddedKeys []string
+ var moddedVals []string
+ for ii := 0; ii < 8; ii++ {
+ if ii != i {
+ moddedKeys = append(moddedKeys, stKeys[ii])
+ moddedVals = append(moddedVals, stVals[ii])
+ }
+ }
+ helper.addSnapStorage(accKey, moddedKeys, moddedVals)
+ }
+ root, snap := helper.CommitAndGenerate()
+ t.Logf("Root: %s\n", root) // Root: 0xca73f6f05ba4ca3024ef340ef3dfca8fdabc1b677ff13f5a9571fd49c16e67ff
+
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
+
+func incKey(key []byte) []byte {
+ for i := len(key) - 1; i >= 0; i-- {
+ key[i]++
+ if key[i] != 0x0 {
+ break
+ }
+ }
+ return key
+}
+
+func decKey(key []byte) []byte {
+ for i := len(key) - 1; i >= 0; i-- {
+ key[i]--
+ if key[i] != 0xff {
+ break
+ }
+ }
+ return key
+}
+
+func populateDangling(disk kvdb.KVBatchStorage) {
+ populate := func(accountHash types.Hash, keys []string, vals []string) {
+ for i, key := range keys {
+ rawdb.WriteStorageSnapshot(disk, accountHash, hashData([]byte(key)), []byte(vals[i]))
+ }
+ }
+ // Dangling storages of the "first" account
+ populate(types.Hash{}, []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ // Dangling storages of the "last" account
+ populate(types.StringToHash("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ // Dangling storages around the account 1
+ hash := decKey(hashData([]byte("acc-1")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ hash = incKey(hashData([]byte("acc-1")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ // Dangling storages around the account 2
+ hash = decKey(hashData([]byte("acc-2")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ hash = incKey(hashData([]byte("acc-2")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ // Dangling storages around the account 3
+ hash = decKey(hashData([]byte("acc-3")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ hash = incKey(hashData([]byte("acc-3")).Bytes())
+ populate(types.BytesToHash(hash), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ // Dangling storages of the random account
+ populate(randomHash(), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ populate(randomHash(), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ populate(randomHash(), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+}
+
+// Tests snapshot generation with dangling storages. Dangling storage means
+// the storage data is existent while the corresponding account data is missing.
+//
+// This test will populate some dangling storages to see if they can be cleaned up.
+func TestGenerateCompleteSnapshotWithDanglingStorage(t *testing.T) {
+ var helper = newHelper()
+
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+ helper.addAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+ helper.addAccount("acc-2", &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+
+ helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+ helper.addAccount("acc-3", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+
+ helper.addSnapStorage("acc-1", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+ helper.addSnapStorage("acc-3", []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"})
+
+ populateDangling(helper.diskdb)
+
+ root, snap := helper.CommitAndGenerate()
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
+
+// Tests snapshot generation with dangling storages. Dangling storage means
+// the storage data is existent while the corresponding account data is missing.
+//
+// This test will populate some dangling storages to see if they can be cleaned up.
+func TestGenerateBrokenSnapshotWithDanglingStorage(t *testing.T) {
+ var helper = newHelper()
+
+ stRoot := helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-1")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+ helper.addTrieAccount("acc-1", &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+ helper.addTrieAccount("acc-2", &Account{Balance: big.NewInt(2), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()})
+
+ helper.makeStorageTrie(types.Hash{}, hashData([]byte("acc-3")), []string{"key-1", "key-2", "key-3"}, []string{"val-1", "val-2", "val-3"}, true)
+ helper.addTrieAccount("acc-3", &Account{Balance: big.NewInt(3), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()})
+
+ populateDangling(helper.diskdb)
+
+ root, snap := helper.CommitAndGenerate()
+ select {
+ case <-snap.genPending:
+ // Snapshot generation succeeded
+
+ case <-time.After(3 * time.Second):
+ t.Errorf("Snapshot generation failed")
+ }
+ checkSnapRoot(t, snap, root)
+
+ // Signal abortion to the generator and wait for it to tear down
+ stop := make(chan *generatorStats)
+ snap.genAbort <- stop
+ <-stop
+}
diff --git a/state/snapshot/holdable_iterator.go b/state/snapshot/holdable_iterator.go
new file mode 100644
index 0000000000..5f1b02091c
--- /dev/null
+++ b/state/snapshot/holdable_iterator.go
@@ -0,0 +1,102 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// holdableIterator is a wrapper of underlying database iterator. It extends
+// the basic iterator interface by adding Hold which can hold the element
+// locally where the iterator is currently located and serve it up next time.
+type holdableIterator struct {
+ it kvdb.Iterator
+ key []byte
+ val []byte
+ atHeld bool
+}
+
+// newHoldableIterator initializes the holdableIterator with the given iterator.
+func newHoldableIterator(it kvdb.Iterator) *holdableIterator {
+ return &holdableIterator{it: it}
+}
+
+// Hold holds the element locally where the iterator is currently located which
+// can be served up next time.
+func (it *holdableIterator) Hold() {
+ if it.it.Key() == nil {
+ return // nothing to hold
+ }
+
+ it.key = types.CopyBytes(it.it.Key())
+ it.val = types.CopyBytes(it.it.Value())
+ it.atHeld = false
+}
+
+// Next moves the iterator to the next key/value pair. It returns whether the
+// iterator is exhausted.
+func (it *holdableIterator) Next() bool {
+ if !it.atHeld && it.key != nil {
+ it.atHeld = true
+ } else if it.atHeld {
+ it.atHeld = false
+ it.key = nil
+ it.val = nil
+ }
+
+ if it.key != nil {
+ return true // shifted to locally held value
+ }
+
+ return it.it.Next()
+}
+
+// Error returns any accumulated error. Exhausting all the key/value pairs
+// is not considered to be an error.
+func (it *holdableIterator) Error() error { return it.it.Error() }
+
+// Release releases associated resources. Release should always succeed and can
+// be called multiple times without causing error.
+func (it *holdableIterator) Release() {
+ it.atHeld = false
+ it.key = nil
+ it.val = nil
+ it.it.Release()
+}
+
+// Key returns the key of the current key/value pair, or nil if done. The caller
+// should not modify the contents of the returned slice, and its contents may
+// change on the next call to Next.
+func (it *holdableIterator) Key() []byte {
+ if it.key != nil {
+ return it.key
+ }
+
+ return it.it.Key()
+}
+
+// Value returns the value of the current key/value pair, or nil if done. The
+// caller should not modify the contents of the returned slice, and its contents
+// may change on the next call to Next.
+func (it *holdableIterator) Value() []byte {
+ if it.val != nil {
+ return it.val
+ }
+
+ return it.it.Value()
+}
diff --git a/state/snapshot/holdable_iterator_test.go b/state/snapshot/holdable_iterator_test.go
new file mode 100644
index 0000000000..f5a776cec7
--- /dev/null
+++ b/state/snapshot/holdable_iterator_test.go
@@ -0,0 +1,180 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+func TestIteratorHold(t *testing.T) {
+ // Create the key-value data store
+ var (
+ content = map[string]string{"k1": "v1", "k2": "v2", "k3": "v3"}
+ order = []string{"k1", "k2", "k3"}
+ db = rawdb.NewMemoryDatabase()
+ )
+
+ for key, val := range content {
+ if err := db.Set([]byte(key), []byte(val)); err != nil {
+ t.Fatalf("failed to insert item %s:%s into database: %v", key, val, err)
+ }
+ }
+
+ // Iterate over the database with the given configs and verify the results
+ it, idx := newHoldableIterator(db.NewIterator(nil, nil)), 0
+
+ // Nothing should be affected by calling Hold on a non-initialized iterator
+ it.Hold()
+
+ for it.Next() {
+ if len(content) <= idx {
+ t.Errorf("more items than expected: checking idx=%d (key %q), expecting len=%d", idx, it.Key(), len(order))
+ break
+ }
+
+ if !bytes.Equal(it.Key(), []byte(order[idx])) {
+ t.Errorf("item %d: key mismatch: have %s, want %s", idx, string(it.Key()), order[idx])
+ }
+
+ if !bytes.Equal(it.Value(), []byte(content[order[idx]])) {
+ t.Errorf("item %d: value mismatch: have %s, want %s", idx, string(it.Value()), content[order[idx]])
+ }
+
+ // Should be safe to call Hold multiple times
+ it.Hold()
+ it.Hold()
+
+ // Shift iterator to the held element
+ it.Next()
+
+ if !bytes.Equal(it.Key(), []byte(order[idx])) {
+ t.Errorf("item %d: key mismatch: have %s, want %s", idx, string(it.Key()), order[idx])
+ }
+
+ if !bytes.Equal(it.Value(), []byte(content[order[idx]])) {
+ t.Errorf("item %d: value mismatch: have %s, want %s", idx, string(it.Value()), content[order[idx]])
+ }
+
+ // Hold/Next combo should work always
+ it.Hold()
+ it.Next()
+
+ if !bytes.Equal(it.Key(), []byte(order[idx])) {
+ t.Errorf("item %d: key mismatch: have %s, want %s", idx, string(it.Key()), order[idx])
+ }
+
+ if !bytes.Equal(it.Value(), []byte(content[order[idx]])) {
+ t.Errorf("item %d: value mismatch: have %s, want %s", idx, string(it.Value()), content[order[idx]])
+ }
+
+ idx++
+ }
+
+ if err := it.Error(); err != nil {
+ t.Errorf("iteration failed: %v", err)
+ }
+
+ if idx != len(order) {
+ t.Errorf("iteration terminated prematurely: have %d, want %d", idx, len(order))
+ }
+
+ db.Close()
+}
+
+func TestReopenIterator(t *testing.T) {
+ var (
+ content = map[types.Hash]string{
+ types.StringToHash("a1"): "v1",
+ types.StringToHash("a2"): "v2",
+ types.StringToHash("a3"): "v3",
+ types.StringToHash("a4"): "v4",
+ types.StringToHash("a5"): "v5",
+ types.StringToHash("a6"): "v6",
+ }
+ order = []types.Hash{
+ types.StringToHash("a1"),
+ types.StringToHash("a2"),
+ types.StringToHash("a3"),
+ types.StringToHash("a4"),
+ types.StringToHash("a5"),
+ types.StringToHash("a6"),
+ }
+ db = rawdb.NewMemoryDatabase()
+ )
+
+ for key, val := range content {
+ rawdb.WriteAccountSnapshot(db, key, []byte(val))
+ }
+
+ checkVal := func(it *holdableIterator, index int) {
+ if !bytes.Equal(it.Key(), append(rawdb.SnapshotAccountPrefix, order[index].Bytes()...)) {
+ t.Fatalf("Unexpected data entry key, want %v got %v", order[index], it.Key())
+ }
+
+ if !bytes.Equal(it.Value(), []byte(content[order[index]])) {
+ t.Fatalf("Unexpected data entry key, want %v got %v", []byte(content[order[index]]), it.Value())
+ }
+ }
+
+ // Iterate over the database with the given configs and verify the results
+ ctx, idx := newGeneratorContext(NilMetrics(), &generatorStats{}, db, nil, nil), -1
+
+ idx++
+ ctx.account.Next()
+ checkVal(ctx.account, idx)
+
+ ctx.reopenIterator(snapAccount)
+ idx++
+ ctx.account.Next()
+ checkVal(ctx.account, idx)
+
+ // reopen twice
+ ctx.reopenIterator(snapAccount)
+ ctx.reopenIterator(snapAccount)
+ idx++
+ ctx.account.Next()
+ checkVal(ctx.account, idx)
+
+ // reopen iterator with held value
+ ctx.account.Next()
+ ctx.account.Hold()
+ ctx.reopenIterator(snapAccount)
+ idx++
+ ctx.account.Next()
+ checkVal(ctx.account, idx)
+
+ // reopen twice iterator with held value
+ ctx.account.Next()
+ ctx.account.Hold()
+ ctx.reopenIterator(snapAccount)
+ ctx.reopenIterator(snapAccount)
+ idx++
+ ctx.account.Next()
+ checkVal(ctx.account, idx)
+
+ // shift to the end and reopen
+ ctx.account.Next() // the end
+ ctx.reopenIterator(snapAccount)
+ ctx.account.Next()
+ if ctx.account.Key() != nil {
+ t.Fatal("Unexpected iterated entry")
+ }
+}
diff --git a/state/snapshot/iterator.go b/state/snapshot/iterator.go
new file mode 100644
index 0000000000..8f411a3c1e
--- /dev/null
+++ b/state/snapshot/iterator.go
@@ -0,0 +1,438 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Iterator is an iterator to step over all the accounts or the specific
+// storage in a snapshot which may or may not be composed of multiple layers.
+//
+// It is the common base embedded by AccountIterator and StorageIterator below;
+// concrete implementations exist for single diff layers, the disk layer, and
+// multi-layer combinations (binary and fast iterators).
+type Iterator interface {
+ // Next steps the iterator forward one element, returning false if exhausted,
+ // or an error if iteration failed for some reason (e.g. root being iterated
+ // becomes stale and garbage collected).
+ Next() bool
+
+ // Error returns any failure that occurred during iteration, which might have
+ // caused a premature iteration exit (e.g. snapshot stack becoming stale).
+ Error() error
+
+ // Hash returns the hash of the account or storage slot the iterator is
+ // currently at.
+ Hash() types.Hash
+
+ // Release releases associated resources. Release should always succeed and
+ // can be called multiple times without causing error.
+ Release()
+}
+
+// AccountIterator is an iterator to step over all the accounts in a snapshot,
+// which may or may not be composed of multiple layers.
+type AccountIterator interface {
+ Iterator
+
+ // Account returns the RLP encoded slim account the iterator is currently at.
+ // An error will be returned if the iterator becomes invalid
+ Account() []byte
+}
+
+// StorageIterator is an iterator to step over the specific storage in a snapshot,
+// which may or may not be composed of multiple layers.
+type StorageIterator interface {
+ Iterator
+
+ // Slot returns the storage slot the iterator is currently at. An error will
+ // be returned if the iterator becomes invalid
+ Slot() []byte
+}
+
+// diffAccountIterator is an account iterator that steps over the accounts (both
+// live and deleted) contained within a single diff layer. Higher order iterators
+// will use the deleted accounts to skip deeper iterators.
+type diffAccountIterator struct {
+ // curHash is the current hash the iterator is positioned on. The field is
+ // explicitly tracked since the referenced diff layer might go stale after
+ // the iterator was positioned and we don't want to fail accessing the old
+ // hash as long as the iterator is not touched any more.
+ curHash types.Hash
+
+ layer *diffLayer // Live layer to retrieve values from
+ keys []types.Hash // Keys left in the layer to iterate
+ fail error // Any failures encountered (stale)
+}
+
+// AccountIterator creates an account iterator over a single diff layer.
+func (dl *diffLayer) AccountIterator(seek types.Hash) AccountIterator {
+ // Seek out the requested starting account
+ hashes := dl.AccountList()
+ // Binary search for the first hash >= seek; everything before it is
+ // excluded from the returned iterator (AccountList is assumed sorted).
+ index := sort.Search(len(hashes), func(i int) bool {
+ return bytes.Compare(seek[:], hashes[i][:]) <= 0
+ })
+
+ // Assemble and returned the already seeked iterator
+ return &diffAccountIterator{
+ layer: dl,
+ keys: hashes[index:],
+ }
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diffAccountIterator) Next() bool {
+ // If the iterator was already stale, consider it a programmer error. Although
+ // we could just return false here, triggering this path would probably mean
+ // somebody forgot to check for Error, so lets blow up instead of undefined
+ // behavior that's hard to debug.
+ if it.fail != nil {
+ panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail))
+ }
+
+ // Stop iterating if all keys were exhausted
+ if len(it.keys) == 0 {
+ return false
+ }
+
+ // A stale layer invalidates the iteration permanently (fail is latched).
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+
+ return false
+ }
+
+ // Iterator seems to be still alive, retrieve and cache the live hash
+ it.curHash = it.keys[0]
+ // key cached, shift the iterator and notify the user of success
+ it.keys = it.keys[1:]
+
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *diffAccountIterator) Error() error {
+ return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diffAccountIterator) Hash() types.Hash {
+ return it.curHash
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+// This method may _fail_, if the underlying layer has been flattened between
+// the call to Next and Account. That type of error will set it.Err.
+// This method assumes that flattening does not delete elements from
+// the accountdata mapping (writing nil into it is fine though), and will panic
+// if elements have been deleted.
+//
+// Note the returned account is not a copy, please don't modify it.
+func (it *diffAccountIterator) Account() []byte {
+ it.layer.lock.RLock()
+
+ blob, ok := it.layer.accountData[it.curHash]
+ if !ok {
+ // A missing entry is only legal if the account was destructed in this
+ // layer; report it as a deleted (nil) account in that case.
+ if _, ok := it.layer.destructSet[it.curHash]; ok {
+ it.layer.lock.RUnlock()
+
+ return nil
+ }
+
+ panic(fmt.Sprintf("iterator referenced non-existent account: %x", it.curHash))
+ }
+
+ it.layer.lock.RUnlock()
+
+ // Staleness is checked after the read on purpose: a stale layer latches
+ // the failure for the caller, but the already-read blob is still returned.
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+ }
+
+ return blob
+}
+
+// Release is a noop for diff account iterators as there are no held resources.
+func (it *diffAccountIterator) Release() {}
+
+// diskAccountIterator is an account iterator that steps over the live accounts
+// contained within a disk layer.
+type diskAccountIterator struct {
+ layer *diskLayer
+ it kvdb.Iterator // Underlying database iterator; nil once exhausted/released
+}
+
+// AccountIterator creates an account iterator over a disk layer.
+func (dl *diskLayer) AccountIterator(seek types.Hash) AccountIterator {
+ // Trailing zero bytes of the seek position are trimmed so the database
+ // iterator starts at the shortest prefix covering the requested hash.
+ pos := types.TrimRightZeroes(seek[:])
+
+ return &diskAccountIterator{
+ layer: dl,
+ it: dl.diskdb.NewIterator(rawdb.SnapshotAccountPrefix, pos),
+ }
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diskAccountIterator) Next() bool {
+ // If the iterator was already exhausted, don't bother
+ if it.it == nil {
+ return false
+ }
+
+ // Try to advance the iterator and release it if we reached the end
+ for {
+ if !it.it.Next() {
+ it.it.Release()
+ it.it = nil
+
+ return false
+ }
+
+ key := it.it.Key()
+ // strict match
+ // Only accept well-formed account snapshot keys (prefix + 32-byte hash);
+ // any other entries sharing the prefix are silently skipped.
+ if len(key) == rawdb.SnapshotPrefixLength+types.HashLength &&
+ bytes.Equal(key[:rawdb.SnapshotPrefixLength], rawdb.SnapshotAccountPrefix) {
+ break
+ }
+ }
+
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+//
+// A diff layer is immutable after creation content wise and can always be fully
+// iterated without error, so this method always returns nil.
+func (it *diskAccountIterator) Error() error {
+ if it.it == nil {
+ return nil // Iterator is exhausted and released
+ }
+
+ return it.it.Error()
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diskAccountIterator) Hash() types.Hash {
+ return types.BytesToHash(it.it.Key()) // The prefix will be truncated
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (it *diskAccountIterator) Account() []byte {
+ return it.it.Value()
+}
+
+// Release releases the database snapshot held during iteration.
+func (it *diskAccountIterator) Release() {
+ // The iterator is auto-released on exhaustion, so make sure it's still alive
+ if it.it != nil {
+ it.it.Release()
+ it.it = nil
+ }
+}
+
+// diffStorageIterator is a storage iterator that steps over the specific storage
+// (both live and deleted) contained within a single diff layer. Higher order
+// iterators will use the deleted slot to skip deeper iterators.
+type diffStorageIterator struct {
+ // curHash is the current hash the iterator is positioned on. The field is
+ // explicitly tracked since the referenced diff layer might go stale after
+ // the iterator was positioned and we don't want to fail accessing the old
+ // hash as long as the iterator is not touched any more.
+ curHash types.Hash
+ account types.Hash
+
+ layer *diffLayer // Live layer to retrieve values from
+ keys []types.Hash // Keys left in the layer to iterate
+ fail error // Any failures encountered (stale)
+}
+
+// StorageIterator creates a storage iterator over a single diff layer.
+// Except the storage iterator is returned, there is an additional flag
+// "destructed" returned. If it's true then it means the whole storage is
+// destructed in this layer(maybe recreated too), don't bother deeper layer
+// for storage retrieval.
+func (dl *diffLayer) StorageIterator(account types.Hash, seek types.Hash) (StorageIterator, bool) {
+ // Create the storage for this account even it's marked
+ // as destructed. The iterator is for the new one which
+ // just has the same address as the deleted one.
+ hashes, destructed := dl.StorageList(account)
+ // Binary search for the first slot hash >= seek (StorageList is assumed sorted).
+ index := sort.Search(len(hashes), func(i int) bool {
+ return bytes.Compare(seek[:], hashes[i][:]) <= 0
+ })
+
+ // Assemble and returned the already seeked iterator
+ return &diffStorageIterator{
+ layer: dl,
+ account: account,
+ keys: hashes[index:],
+ }, destructed
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diffStorageIterator) Next() bool {
+ // If the iterator was already stale, consider it a programmer error. Although
+ // we could just return false here, triggering this path would probably mean
+ // somebody forgot to check for Error, so lets blow up instead of undefined
+ // behavior that's hard to debug.
+ if it.fail != nil {
+ panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail))
+ }
+ // Stop iterating if all keys were exhausted
+ if len(it.keys) == 0 {
+ return false
+ }
+
+ // A stale layer invalidates the iteration permanently (fail is latched).
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+
+ return false
+ }
+
+ // Iterator seems to be still alive, retrieve and cache the live hash
+ it.curHash = it.keys[0]
+ // key cached, shift the iterator and notify the user of success
+ it.keys = it.keys[1:]
+
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *diffStorageIterator) Error() error {
+ return it.fail
+}
+
+// Hash returns the hash of the storage slot the iterator is currently at.
+func (it *diffStorageIterator) Hash() types.Hash {
+ return it.curHash
+}
+
+// Slot returns the raw storage slot value the iterator is currently at.
+// This method may _fail_, if the underlying layer has been flattened between
+// the call to Next and Value. That type of error will set it.Err.
+// This method assumes that flattening does not delete elements from
+// the storage mapping (writing nil into it is fine though), and will panic
+// if elements have been deleted.
+//
+// Note the returned slot is not a copy, please don't modify it.
+func (it *diffStorageIterator) Slot() []byte {
+ it.layer.lock.RLock()
+
+ storage, ok := it.layer.storageData[it.account]
+ if !ok {
+ panic(fmt.Sprintf("iterator referenced non-existent account storage: %x", it.account))
+ }
+ // Storage slot might be nil(deleted), but it must exist
+ blob, ok := storage[it.curHash]
+ if !ok {
+ panic(fmt.Sprintf("iterator referenced non-existent storage slot: %x", it.curHash))
+ }
+
+ it.layer.lock.RUnlock()
+
+ // Staleness is checked after the read on purpose: a stale layer latches
+ // the failure for the caller, but the already-read blob is still returned.
+ if it.layer.Stale() {
+ it.fail, it.keys = ErrSnapshotStale, nil
+ }
+
+ return blob
+}
+
+// Release is a noop for diff account iterators as there are no held resources.
+func (it *diffStorageIterator) Release() {}
+
+// diskStorageIterator is a storage iterator that steps over the live storage
+// contained within a disk layer.
+type diskStorageIterator struct {
+ layer *diskLayer
+ account types.Hash
+ it kvdb.Iterator // Underlying database iterator; nil once exhausted/released
+}
+
+// StorageIterator creates a storage iterator over a disk layer.
+// If the whole storage is destructed, then all entries in the disk
+// layer are deleted already. So the "destructed" flag returned here
+// is always false.
+func (dl *diskLayer) StorageIterator(account types.Hash, seek types.Hash) (StorageIterator, bool) {
+ // Trailing zero bytes of the seek position are trimmed so the database
+ // iterator starts at the shortest prefix covering the requested slot.
+ pos := types.TrimRightZeroes(seek[:])
+
+ return &diskStorageIterator{
+ layer: dl,
+ account: account,
+ it: dl.diskdb.NewIterator(rawdb.SnapshotsStorageKey(account), pos),
+ }, false
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diskStorageIterator) Next() bool {
+ // If the iterator was already exhausted, don't bother
+ if it.it == nil {
+ return false
+ }
+ // Try to advance the iterator and release it if we reached the end
+ for {
+ if !it.it.Next() {
+ it.it.Release()
+ it.it = nil
+
+ return false
+ }
+
+ key := it.it.Key()
+ // strict match
+ // Only accept well-formed storage snapshot keys
+ // (prefix + 32-byte account hash + 32-byte slot hash); skip the rest.
+ if (len(key) == rawdb.SnapshotPrefixLength+types.HashLength+types.HashLength) &&
+ bytes.Equal(key[:rawdb.SnapshotPrefixLength], rawdb.SnapshotStoragePrefix) {
+ break
+ }
+ }
+
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+//
+// A diff layer is immutable after creation content wise and can always be fully
+// iterated without error, so this method always returns nil.
+func (it *diskStorageIterator) Error() error {
+ if it.it == nil {
+ return nil // Iterator is exhausted and released
+ }
+
+ return it.it.Error()
+}
+
+// Hash returns the hash of the storage slot the iterator is currently at.
+func (it *diskStorageIterator) Hash() types.Hash {
+ return types.BytesToHash(it.it.Key()) // The prefix will be truncated
+}
+
+// Slot returns the raw storage slot content the iterator is currently at.
+func (it *diskStorageIterator) Slot() []byte {
+ return it.it.Value()
+}
+
+// Release releases the database snapshot held during iteration.
+func (it *diskStorageIterator) Release() {
+ // The iterator is auto-released on exhaustion, so make sure it's still alive
+ if it.it != nil {
+ it.it.Release()
+ it.it = nil
+ }
+}
diff --git a/state/snapshot/iterator_binary.go b/state/snapshot/iterator_binary.go
new file mode 100644
index 0000000000..c093cb46d0
--- /dev/null
+++ b/state/snapshot/iterator_binary.go
@@ -0,0 +1,247 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+ "bytes"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// binaryIterator is a simplistic iterator to step over the accounts or storage
+// in a snapshot, which may or may not be composed of multiple layers. Performance
+// wise this iterator is slow, it's meant for cross validating the fast one,
+type binaryIterator struct {
+ a Iterator // Iterator over the current (topmost) layer
+ b Iterator // Iterator over everything below the current layer
+ aDone bool // a is exhausted
+ bDone bool // b is exhausted
+ accountIterator bool // True when iterating accounts, false for storage slots
+ k types.Hash // Current key the iterator is positioned on
+ account types.Hash // Account hash the storage slots belong to (storage mode only)
+ fail error // Any failure encountered during iteration
+}
+
+// initBinaryAccountIterator creates a simplistic iterator to step over all the
+// accounts in a slow, but easily verifiable way. Note this function is used for
+// initialization, use `newBinaryAccountIterator` as the API.
+func (dl *diffLayer) initBinaryAccountIterator() Iterator {
+ parent, ok := dl.parent.(*diffLayer)
+ if !ok {
+ // Parent is not a diff layer (i.e. the bottom-most layer): pair this
+ // layer's iterator directly with the parent's own account iterator.
+ l := &binaryIterator{
+ a: dl.AccountIterator(types.Hash{}),
+ b: dl.Parent().AccountIterator(types.Hash{}),
+ accountIterator: true,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = !l.b.Next()
+
+ return l
+ }
+
+ // Parent is another diff layer: recurse, building a right-leaning chain of
+ // binary iterators down to the disk layer.
+ l := &binaryIterator{
+ a: dl.AccountIterator(types.Hash{}),
+ b: parent.initBinaryAccountIterator(),
+ accountIterator: true,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = !l.b.Next()
+
+ return l
+}
+
+// initBinaryStorageIterator creates a simplistic iterator to step over all the
+// storage slots in a slow, but easily verifiable way. Note this function is used
+// for initialization, use `newBinaryStorageIterator` as the API.
+func (dl *diffLayer) initBinaryStorageIterator(account types.Hash) Iterator {
+ parent, ok := dl.parent.(*diffLayer)
+ if !ok {
+ // If the storage in this layer is already destructed, discard all
+ // deeper layers but still return an valid single-branch iterator.
+ a, destructed := dl.StorageIterator(account, types.Hash{})
+ if destructed {
+ // bDone stays true forever, so only branch `a` is ever consumed.
+ l := &binaryIterator{
+ a: a,
+ account: account,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = true
+
+ return l
+ }
+
+ // The parent is disk layer, don't need to take care "destructed"
+ // anymore.
+ b, _ := dl.Parent().StorageIterator(account, types.Hash{})
+ l := &binaryIterator{
+ a: a,
+ b: b,
+ account: account,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = !l.b.Next()
+
+ return l
+ }
+
+ // If the storage in this layer is already destructed, discard all
+ // deeper layers but still return an valid single-branch iterator.
+ a, destructed := dl.StorageIterator(account, types.Hash{})
+ if destructed {
+ // bDone stays true forever, so only branch `a` is ever consumed.
+ l := &binaryIterator{
+ a: a,
+ account: account,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = true
+
+ return l
+ }
+
+ // Parent is another diff layer: recurse down the stack.
+ l := &binaryIterator{
+ a: a,
+ b: parent.initBinaryStorageIterator(account),
+ account: account,
+ }
+
+ l.aDone = !l.a.Next()
+ l.bDone = !l.b.Next()
+
+ return l
+}
+
+// Next steps the iterator forward one element, returning false if exhausted,
+// or an error if iteration failed for some reason (e.g. root being iterated
+// becomes stale and garbage collected).
+//
+// The two branches are merged in sorted key order; when both branches hold the
+// same key, branch `a` (the shallower layer) wins and `b` is skipped past it.
+func (it *binaryIterator) Next() bool {
+ if it.aDone && it.bDone {
+ return false
+ }
+
+first:
+ if it.aDone {
+ it.k = it.b.Hash()
+ it.bDone = !it.b.Next()
+
+ return true
+ }
+
+ if it.bDone {
+ it.k = it.a.Hash()
+ it.aDone = !it.a.Next()
+
+ return true
+ }
+
+ nextA, nextB := it.a.Hash(), it.b.Hash()
+
+ if diff := bytes.Compare(nextA[:], nextB[:]); diff < 0 {
+ it.aDone = !it.a.Next()
+ it.k = nextA
+
+ return true
+ } else if diff == 0 {
+ // Now we need to advance one of them
+ it.aDone = !it.a.Next()
+
+ goto first
+ }
+
+ it.bDone = !it.b.Next()
+ it.k = nextB
+
+ return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (it *binaryIterator) Error() error {
+ return it.fail
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *binaryIterator) Hash() types.Hash {
+ return it.k
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at, or
+// nil if the iterated snapshot stack became stale (you can check Error after
+// to see if it failed or not).
+//
+// Note the returned account is not a copy, please don't modify it.
+func (it *binaryIterator) Account() []byte {
+ // Mode guard: in storage mode there is no account to return.
+ if !it.accountIterator {
+ return nil
+ }
+
+ // The topmost iterator must be `diffAccountIterator`
+ // The value is resolved through the layer stack rather than from a single
+ // layer, so the most recent version of the account is returned.
+ blob, err := it.a.(*diffAccountIterator).layer.AccountRLP(it.k)
+ if err != nil {
+ it.fail = err
+
+ return nil
+ }
+
+ return blob
+}
+
+// Slot returns the raw storage slot data the iterator is currently at, or
+// nil if the iterated snapshot stack became stale (you can check Error after
+// to see if it failed or not).
+//
+// Note the returned slot is not a copy, please don't modify it.
+func (it *binaryIterator) Slot() []byte {
+ // Mode guard: in account mode there is no slot to return.
+ if it.accountIterator {
+ return nil
+ }
+
+ blob, err := it.a.(*diffStorageIterator).layer.Storage(it.account, it.k)
+ if err != nil {
+ it.fail = err
+
+ return nil
+ }
+
+ return blob
+}
+
+// Release recursively releases all the iterators in the stack.
+func (it *binaryIterator) Release() {
+ it.a.Release()
+ it.b.Release()
+}
+
+// newBinaryAccountIterator creates a simplistic account iterator to step over
+// all the accounts in a slow, but easily verifiable way.
+func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
+ iter := dl.initBinaryAccountIterator()
+ // The init helper always returns a *binaryIterator, which implements
+ // AccountIterator, so the assertion cannot fail in practice.
+ //nolint:forcetypeassert
+ return iter.(AccountIterator)
+}
+
+// newBinaryStorageIterator creates a simplistic account iterator to step over
+// all the storage slots in a slow, but easily verifiable way.
+func (dl *diffLayer) newBinaryStorageIterator(account types.Hash) StorageIterator {
+ iter := dl.initBinaryStorageIterator(account)
+ // Same reasoning as above: *binaryIterator implements StorageIterator.
+ //nolint:forcetypeassert
+ return iter.(StorageIterator)
+}
diff --git a/state/snapshot/iterator_fast.go b/state/snapshot/iterator_fast.go
new file mode 100644
index 0000000000..dabd07f559
--- /dev/null
+++ b/state/snapshot/iterator_fast.go
@@ -0,0 +1,378 @@
+package snapshot
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// weightedIterator is a iterator with an assigned weight. It is used to prioritise
+// which account or storage slot is the correct one if multiple iterators find the
+// same one (modified in multiple consecutive blocks).
+type weightedIterator struct {
+ it Iterator
+ priority int // Layer depth; lower values (shallower layers) win on key ties
+}
+
+// weightedIterators is a set of iterators implementing the sort.Interface.
+type weightedIterators []*weightedIterator
+
+// Len implements sort.Interface, returning the number of active iterators.
+func (its weightedIterators) Len() int { return len(its) }
+
+// Less implements sort.Interface, returning which of two iterators in the stack
+// is before the other.
+func (its weightedIterators) Less(i, j int) bool {
+ // Order the iterators primarily by the account hashes
+ hashI := its[i].it.Hash()
+ hashJ := its[j].it.Hash()
+
+ switch bytes.Compare(hashI[:], hashJ[:]) {
+ case -1:
+ return true
+ case 1:
+ return false
+ }
+
+ // Same account/storage-slot in multiple layers, split by priority
+ return its[i].priority < its[j].priority
+}
+
+// Swap implements sort.Interface, swapping two entries in the iterator stack.
+func (its weightedIterators) Swap(i, j int) {
+ its[i], its[j] = its[j], its[i]
+}
+
+// fastIterator is a more optimized multi-layer iterator which maintains a
+// direct mapping of all iterators leading down to the bottom layer.
+type fastIterator struct {
+ tree *Tree // Snapshot tree to reinitialize stale sub-iterators with
+ root types.Hash // Root hash to reinitialize stale sub-iterators through
+
+ curAccount []byte // Cached account blob for the current position (account mode)
+ curSlot []byte // Cached slot blob for the current position (storage mode)
+
+ iterators weightedIterators // One live sub-iterator per layer, kept key-sorted
+ initiated bool // Whether the first Next() has been consumed already
+ account bool // True for account iteration, false for storage
+ fail error // Any failure encountered during iteration
+}
+
+// newFastIterator creates a new hierarchical account or storage iterator with one
+// element per diff layer. The returned combo iterator can be used to walk over
+// the entire snapshot diff stack simultaneously.
+func newFastIterator(tree *Tree, root, account, seek types.Hash, accountIterator bool) (*fastIterator, error) {
+ snap := tree.Snapshot(root)
+ if snap == nil {
+ return nil, fmt.Errorf("unknown snapshot: %x", root)
+ }
+
+ fi := &fastIterator{
+ tree: tree,
+ root: root,
+ account: accountIterator,
+ }
+
+ // NOTE(review): a failed assertion leaves current nil and yields an empty
+ // iterator — presumably Snapshot always returns the internal snapshot
+ // interface type; confirm against the Tree implementation.
+ current, _ := snap.(snapshot)
+
+ for depth := 0; current != nil; depth++ {
+ if accountIterator {
+ fi.iterators = append(fi.iterators, &weightedIterator{
+ it: current.AccountIterator(seek),
+ priority: depth,
+ })
+ } else {
+ // If the whole storage is destructed in this layer, don't
+ // bother deeper layer anymore. But we should still keep
+ // the iterator for this layer, since the iterator can contain
+ // some valid slots which belongs to the re-created account.
+ it, destructed := current.StorageIterator(account, seek)
+ fi.iterators = append(fi.iterators, &weightedIterator{
+ it: it,
+ priority: depth,
+ })
+
+ if destructed {
+ break
+ }
+ }
+
+ current = current.Parent()
+ }
+
+ fi.init()
+
+ return fi, nil
+}
+
+// init walks over all the iterators and resolves any clashes between them, after
+// which it prepares the stack for step-by-step iteration.
+func (fi *fastIterator) init() {
+ // Track which account hashes are iterators positioned on
+ var positioned = make(map[types.Hash]int)
+
+ // Position all iterators and track how many remain live
+ for i := 0; i < len(fi.iterators); i++ {
+ // Retrieve the first element and if it clashes with a previous iterator,
+ // advance either the current one or the old one. Repeat until nothing is
+ // clashing any more.
+ it := fi.iterators[i]
+
+ for {
+ // If the iterator is exhausted, drop it off the end
+ if !it.it.Next() {
+ it.it.Release()
+
+ // Swap-remove: move the last live iterator into slot i...
+ last := len(fi.iterators) - 1
+
+ fi.iterators[i] = fi.iterators[last]
+ fi.iterators[last] = nil
+ fi.iterators = fi.iterators[:last]
+
+ // ...and step back so the outer loop's i++ revisits slot i.
+ i--
+
+ break
+ }
+
+ // The iterator is still alive, check for collisions with previous ones
+ hash := it.it.Hash()
+ if other, exist := positioned[hash]; !exist {
+ positioned[hash] = i
+
+ break
+ } else {
+ // Iterators collide, one needs to be progressed, use priority to
+ // determine which.
+ //
+ // This whole else-block can be avoided, if we instead
+ // do an initial priority-sort of the iterators. If we do that,
+ // then we'll only wind up here if a lower-priority (preferred) iterator
+ // has the same value, and then we will always just continue.
+ // However, it costs an extra sort, so it's probably not better
+ if fi.iterators[other].priority < it.priority {
+ // The 'it' should be progressed
+ continue
+ } else {
+ // The 'other' should be progressed, swap them
+ it = fi.iterators[other]
+ fi.iterators[other], fi.iterators[i] = fi.iterators[i], fi.iterators[other]
+
+ continue
+ }
+ }
+ }
+ }
+
+ // Re-sort the entire list
+ sort.Sort(fi.iterators)
+
+ fi.initiated = false
+}
+
+// Next steps the iterator forward one element, returning false if exhausted.
+//
+// After a successful call, the current value is cached in curAccount/curSlot;
+// nil values (deletions) are skipped so callers only ever observe live entries.
+func (fi *fastIterator) Next() bool {
+ if len(fi.iterators) == 0 {
+ return false
+ }
+
+ if !fi.initiated {
+ // Don't forward first time -- we had to 'Next' once in order to
+ // do the sorting already
+ fi.initiated = true
+
+ if fi.account {
+ //nolint:forcetypeassert
+ fi.curAccount = fi.iterators[0].it.(AccountIterator).Account()
+ } else {
+ //nolint:forcetypeassert
+ fi.curSlot = fi.iterators[0].it.(StorageIterator).Slot()
+ }
+
+ if innerErr := fi.iterators[0].it.Error(); innerErr != nil {
+ fi.fail = innerErr
+
+ return false
+ }
+
+ if fi.curAccount != nil || fi.curSlot != nil {
+ // Implicit else: we've hit a nil-account or nil-slot, and need to
+ // fall through to the loop below to land on something non-nil
+ return true
+ }
+ }
+
+ // If an account or a slot is deleted in one of the layers, the key will
+ // still be there, but the actual value will be nil. However, the iterator
+ // should not export nil-values (but instead simply omit the key), so we
+ // need to loop here until we either
+ // - get a non-nil value,
+ // - hit an error,
+ // - or exhaust the iterator
+ for {
+ if !fi.next(0) {
+ return false // exhausted
+ }
+
+ if fi.account {
+ //nolint:forcetypeassert
+ fi.curAccount = fi.iterators[0].it.(AccountIterator).Account()
+ } else {
+ //nolint:forcetypeassert
+ fi.curSlot = fi.iterators[0].it.(StorageIterator).Slot()
+ }
+
+ if innerErr := fi.iterators[0].it.Error(); innerErr != nil {
+ fi.fail = innerErr
+
+ return false // error
+ }
+
+ if fi.curAccount != nil || fi.curSlot != nil {
+ break // non-nil value found
+ }
+ }
+
+ return true
+}
+
+// next handles the next operation internally and should be invoked when we know
+// that two elements in the list may have the same value.
+//
+// For example, if the iterated hashes become [2,3,5,5,8,9,10], then we should
+// invoke next(3), which will call Next on elem 3 (the second '5') and will
+// cascade along the list, applying the same operation if needed.
+func (fi *fastIterator) next(idx int) bool {
+ // If this particular iterator got exhausted, remove it and return true (the
+ // next one is surely not exhausted yet, otherwise it would have been removed
+ // already).
+ if it := fi.iterators[idx].it; !it.Next() {
+ it.Release()
+
+ fi.iterators = append(fi.iterators[:idx], fi.iterators[idx+1:]...)
+
+ return len(fi.iterators) > 0
+ }
+
+ // If there's no one left to cascade into, return
+ if idx == len(fi.iterators)-1 {
+ return true
+ }
+
+ // We next-ed the iterator at 'idx', now we may have to re-sort that element
+ var (
+ cur, next = fi.iterators[idx], fi.iterators[idx+1]
+ curHash, nextHash = cur.it.Hash(), next.it.Hash()
+ )
+
+ if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+ // It is still in correct place
+ return true
+ } else if diff == 0 && cur.priority < next.priority {
+ // So still in correct place, but we need to iterate on the next
+ // (recursion depth is bounded by the number of layers holding the key)
+ fi.next(idx + 1)
+
+ return true
+ }
+
+ // At this point, the iterator is in the wrong location, but the remaining
+ // list is sorted. Find out where to move the item.
+ clash := -1
+ index := sort.Search(len(fi.iterators), func(n int) bool {
+ // The iterator always advances forward, so anything before the old slot
+ // is known to be behind us, so just skip them altogether. This actually
+ // is an important clause since the sort order got invalidated.
+ if n < idx {
+ return false
+ }
+
+ if n == len(fi.iterators)-1 {
+ // Can always place an elem last
+ return true
+ }
+
+ nextHash := fi.iterators[n+1].it.Hash()
+ if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
+ return true
+ } else if diff > 0 {
+ return false
+ }
+ // The elem we're placing it next to has the same value,
+ // so whichever winds up on n+1 will need further iteration
+ clash = n + 1
+
+ return cur.priority < fi.iterators[n+1].priority
+ })
+
+ fi.move(idx, index)
+
+ if clash != -1 {
+ fi.next(clash)
+ }
+
+ return true
+}
+
+// move advances an iterator to another position in the list.
+// It shifts the elements in (index, newpos] one slot left and drops the moved
+// element into newpos; callers must guarantee newpos >= index.
+func (fi *fastIterator) move(index, newpos int) {
+ elem := fi.iterators[index]
+ copy(fi.iterators[index:], fi.iterators[index+1:newpos+1])
+ fi.iterators[newpos] = elem
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. snapshot stack becoming stale).
+func (fi *fastIterator) Error() error {
+ return fi.fail
+}
+
+// Hash returns the current key
+// (i.e. the key of the head sub-iterator, which is always the smallest).
+func (fi *fastIterator) Hash() types.Hash {
+ return fi.iterators[0].it.Hash()
+}
+
+// Account returns the current account blob.
+// Note the returned account is not a copy, please don't modify it.
+func (fi *fastIterator) Account() []byte {
+ return fi.curAccount
+}
+
+// Slot returns the current storage slot.
+// Note the returned slot is not a copy, please don't modify it.
+func (fi *fastIterator) Slot() []byte {
+ return fi.curSlot
+}
+
+// Release iterates over all the remaining live layer iterators and releases each
+// of them individually.
+func (fi *fastIterator) Release() {
+ for _, it := range fi.iterators {
+ it.it.Release()
+ }
+
+ fi.iterators = nil
+}
+
+// Debug is a convenience helper during testing
+// (prints each sub-iterator's priority and first key byte to stdout).
+func (fi *fastIterator) Debug() {
+ for _, it := range fi.iterators {
+ fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Hash()[0])
+ }
+
+ fmt.Println()
+}
+
+// newFastAccountIterator creates a new hierarchical account iterator with one
+// element per diff layer. The returned combo iterator can be used to walk over
+// the entire snapshot diff stack simultaneously.
+func newFastAccountIterator(tree *Tree, root, seek types.Hash) (AccountIterator, error) {
+ return newFastIterator(tree, root, types.Hash{}, seek, true)
+}
+
+// newFastStorageIterator creates a new hierarchical storage iterator with one
+// element per diff layer. The returned combo iterator can be used to walk over
+// the entire snapshot diff stack simultaneously.
+func newFastStorageIterator(tree *Tree, root, account, seek types.Hash) (StorageIterator, error) {
+ return newFastIterator(tree, root, account, seek, false)
+}
diff --git a/state/snapshot/iterator_test.go b/state/snapshot/iterator_test.go
new file mode 100644
index 0000000000..6e0960f827
--- /dev/null
+++ b/state/snapshot/iterator_test.go
@@ -0,0 +1,1085 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// TestAccountIteratorBasics tests some simple single-layer(diff and disk) iteration
+func TestAccountIteratorBasics(t *testing.T) {
+	var (
+		destructs = make(map[types.Hash]struct{})
+		accounts  = make(map[types.Hash][]byte)
+		storage   = make(map[types.Hash]map[types.Hash][]byte)
+	)
+	// Fill up a parent: 100 random accounts, ~1/4 also marked destructed,
+	// ~1/2 carrying one random storage slot.
+	for i := 0; i < 100; i++ {
+		h := randomHash()
+		data := randomAccount()
+
+		accounts[h] = data
+		if rand.Intn(4) == 0 {
+			destructs[h] = struct{}{}
+		}
+		if rand.Intn(2) == 0 {
+			accStorage := make(map[types.Hash][]byte)
+			value := make([]byte, 32)
+			rand.Read(value)
+			accStorage[randomHash()] = value
+			storage[h] = accStorage
+		}
+	}
+	// Add some (identical) layers on top
+	diffLayer := newDiffLayer(emptyLayer(), types.Hash{}, copyDestructs(destructs), copyAccounts(accounts), copyStorage(storage), hclog.NewNullLogger(), NilMetrics())
+	it := diffLayer.AccountIterator(types.Hash{})
+	verifyIterator(t, 100, it, verifyNothing) // Nil is allowed for single layer iterator
+
+	// Flattening into disk must preserve the same 100 accounts.
+	diskLayer := diffToDisk(diffLayer)
+	it = diskLayer.AccountIterator(types.Hash{})
+	verifyIterator(t, 100, it, verifyNothing) // Nil is allowed for single layer iterator
+}
+
+// TestStorageIteratorBasics tests some simple single-layer(diff and disk) iteration for storage
+func TestStorageIteratorBasics(t *testing.T) {
+	var (
+		// nilStorage tracks, per account, how many slots were written as nil
+		// (i.e. deleted) so the disk-layer expectation can subtract them.
+		nilStorage = make(map[types.Hash]int)
+		accounts   = make(map[types.Hash][]byte)
+		storage    = make(map[types.Hash]map[types.Hash][]byte)
+	)
+	// Fill some random data
+	for i := 0; i < 10; i++ {
+		h := randomHash()
+		accounts[h] = randomAccount()
+
+		accStorage := make(map[types.Hash][]byte)
+		value := make([]byte, 32)
+
+		var nilstorage int
+		for i := 0; i < 100; i++ {
+			rand.Read(value)
+			if rand.Intn(2) == 0 {
+				accStorage[randomHash()] = types.CopyBytes(value)
+			} else {
+				accStorage[randomHash()] = nil // delete slot
+				nilstorage += 1
+			}
+		}
+		storage[h] = accStorage
+		nilStorage[h] = nilstorage
+	}
+	// Add some (identical) layers on top
+	diffLayer := newDiffLayer(emptyLayer(), types.Hash{}, nil, copyAccounts(accounts),
+		copyStorage(storage), hclog.NewNullLogger(), NilMetrics())
+	for account := range accounts {
+		it, _ := diffLayer.StorageIterator(account, types.Hash{})
+		verifyIterator(t, 100, it, verifyNothing) // Nil is allowed for single layer iterator
+	}
+
+	// The disk layer drops deleted slots entirely, hence the reduced count.
+	diskLayer := diffToDisk(diffLayer)
+	for account := range accounts {
+		it, _ := diskLayer.StorageIterator(account, types.Hash{})
+		verifyIterator(t, 100-nilStorage[account], it, verifyNothing) // Nil is allowed for single layer iterator
+	}
+}
+
+// testIterator is a stub Iterator backed by a fixed byte list; each byte is
+// exposed as a hash. It exists purely to drive fastIterator in unit tests.
+type testIterator struct {
+	values []byte
+}
+
+// newTestIterator wraps the given bytes in a testIterator.
+func newTestIterator(values ...byte) *testIterator {
+	return &testIterator{values}
+}
+
+// Seek is unused by the tests exercising this stub.
+func (ti *testIterator) Seek(types.Hash) {
+	panic("implement me")
+}
+
+// Next drops the current head BEFORE checking for exhaustion — this matches
+// fastIterator's expectation that the wrapped iterator is already positioned
+// on its first element when handed over.
+func (ti *testIterator) Next() bool {
+	ti.values = ti.values[1:]
+	return len(ti.values) > 0
+}
+
+func (ti *testIterator) Error() error {
+	return nil
+}
+
+// Hash encodes the current head byte as a (right-aligned) hash.
+func (ti *testIterator) Hash() types.Hash {
+	return types.BytesToHash([]byte{ti.values[0]})
+}
+
+func (ti *testIterator) Account() []byte {
+	return nil
+}
+
+func (ti *testIterator) Slot() []byte {
+	return nil
+}
+
+func (ti *testIterator) Release() {}
+
+// TestFastIteratorBasics feeds several overlapping sorted key lists into a
+// fastIterator and checks the merged output is the sorted, de-duplicated union.
+func TestFastIteratorBasics(t *testing.T) {
+	type testCase struct {
+		lists   [][]byte
+		expKeys []byte
+	}
+	for i, tc := range []testCase{
+		{lists: [][]byte{{0, 1, 8}, {1, 2, 8}, {2, 9}, {4},
+			{7, 14, 15}, {9, 13, 15, 16}},
+			expKeys: []byte{0, 1, 2, 4, 7, 8, 9, 13, 14, 15, 16}},
+		{lists: [][]byte{{0, 8}, {1, 2, 8}, {7, 14, 15}, {8, 9},
+			{9, 10}, {10, 13, 15, 16}},
+			expKeys: []byte{0, 1, 2, 7, 8, 9, 10, 13, 14, 15, 16}},
+	} {
+		var iterators []*weightedIterator
+		// Use a distinct index name here so it cannot shadow the test-case
+		// index `i` referenced in the error messages below.
+		for prio, data := range tc.lists {
+			it := newTestIterator(data...)
+			iterators = append(iterators, &weightedIterator{it, prio})
+		}
+		fi := &fastIterator{
+			iterators: iterators,
+			initiated: false,
+		}
+		count := 0
+		for fi.Next() {
+			if got, exp := fi.Hash()[31], tc.expKeys[count]; exp != got {
+				t.Errorf("tc %d, [%d]: got %d exp %d", i, count, got, exp)
+			}
+			count++
+		}
+		// Ensure the iterator produced the full union, not a silent prefix.
+		if count != len(tc.expKeys) {
+			t.Errorf("tc %d: iterator count mismatch: have %d, want %d", i, count, len(tc.expKeys))
+		}
+	}
+}
+
+// verifyContent selects which payload (if any) verifyIterator asserts to be
+// non-empty for every element it visits.
+type verifyContent int
+
+const (
+	verifyNothing verifyContent = iota
+	verifyAccount
+	verifyStorage
+)
+
+// verifyIterator walks it to exhaustion, asserting strictly ascending hash
+// order, the expected element count, non-empty payloads per the verify mode,
+// and a nil terminal error.
+func verifyIterator(t *testing.T, expCount int, it Iterator, verify verifyContent) {
+	t.Helper()
+
+	var (
+		count int
+		last  types.Hash
+	)
+	for it.Next() {
+		hash := it.Hash()
+		if bytes.Compare(last[:], hash[:]) >= 0 {
+			t.Errorf("wrong order: %x >= %x", last, hash)
+		}
+		count++
+
+		switch verify {
+		case verifyAccount:
+			if len(it.(AccountIterator).Account()) == 0 {
+				t.Errorf("iterator returned nil-value for hash %x", hash)
+			}
+		case verifyStorage:
+			if len(it.(StorageIterator).Slot()) == 0 {
+				t.Errorf("iterator returned nil-value for hash %x", hash)
+			}
+		}
+		last = hash
+	}
+	if count != expCount {
+		t.Errorf("iterator count mismatch: have %d, want %d", count, expCount)
+	}
+	if err := it.Error(); err != nil {
+		t.Errorf("iterator failed: %v", err)
+	}
+}
+
+// TestAccountIteratorTraversal tests some simple multi-layer iteration.
+func TestAccountIteratorTraversal(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil,
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil, logger)
+
+	// Verify the single and multi-layer iterators.
+	// 3 = keys in the top layer alone; 7 = distinct keys across all layers
+	// (aa, bb, cc, dd, ee, f0, ff).
+	head := snaps.Snapshot(types.StringToHash("0x04"))
+
+	verifyIterator(t, 3, head.(snapshot).AccountIterator(types.Hash{}), verifyNothing)
+	verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator(), verifyAccount)
+
+	it, _ := snaps.AccountIterator(types.StringToHash("0x04"), types.Hash{})
+	verifyIterator(t, 7, it, verifyAccount)
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(types.StringToHash("0x04"), 2)
+	verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator(), verifyAccount)
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.Hash{})
+	verifyIterator(t, 7, it, verifyAccount)
+	it.Release()
+}
+
+// TestStorageIteratorTraversal tests multi-layer iteration over a single
+// account's storage slots.
+func TestStorageIteratorTraversal(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x02", "0x03"}}, nil), logger)
+
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x04", "0x05", "0x06"}}, nil), logger)
+
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x02", "0x03"}}, nil), logger)
+
+	// Verify the single and multi-layer iterators.
+	// 3 = slots in the top layer; 6 = union of slots 0x01-0x06 across layers.
+	head := snaps.Snapshot(types.StringToHash("0x04"))
+
+	diffIter, _ := head.(snapshot).StorageIterator(types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 3, diffIter, verifyNothing)
+	verifyIterator(t, 6, head.(*diffLayer).newBinaryStorageIterator(types.StringToHash("0xaa")), verifyStorage)
+
+	it, _ := snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 6, it, verifyStorage)
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(types.StringToHash("0x04"), 2)
+	verifyIterator(t, 6, head.(*diffLayer).newBinaryStorageIterator(types.StringToHash("0xaa")), verifyStorage)
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 6, it, verifyStorage)
+	it.Release()
+}
+
+// TestAccountIteratorTraversalValues tests some multi-layer iteration, where we
+// also expect the correct values to show up.
+func TestAccountIteratorTraversalValues(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Create a batch of account sets to seed subsequent layers with
+	var (
+		a = make(map[types.Hash][]byte)
+		b = make(map[types.Hash][]byte)
+		c = make(map[types.Hash][]byte)
+		d = make(map[types.Hash][]byte)
+		e = make(map[types.Hash][]byte)
+		f = make(map[types.Hash][]byte)
+		g = make(map[types.Hash][]byte)
+		h = make(map[types.Hash][]byte)
+	)
+
+	for i := byte(2); i < 0xff; i++ {
+		a[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i))
+		if i > 20 && i%2 == 0 {
+			b[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 1, i))
+		}
+		if i%4 == 0 {
+			c[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 2, i))
+		}
+		if i%7 == 0 {
+			d[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 3, i))
+		}
+		if i%8 == 0 {
+			e[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 4, i))
+		}
+		// NOTE(review): the original condition `i > 50 || i < 85` was a
+		// tautology (true for every byte), so layer f shadowed every key.
+		// The intended band is the open range (50, 85).
+		if i > 50 && i < 85 {
+			f[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 5, i))
+		}
+		if i%64 == 0 {
+			g[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 6, i))
+		}
+		if i%128 == 0 {
+			h[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i))
+		}
+	}
+	// Assemble a stack of snapshots from the account layers
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil, a, nil, logger)
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil, b, nil, logger)
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil, c, nil, logger)
+	snaps.Update(types.StringToHash("0x05"), types.StringToHash("0x04"), nil, d, nil, logger)
+	snaps.Update(types.StringToHash("0x06"), types.StringToHash("0x05"), nil, e, nil, logger)
+	snaps.Update(types.StringToHash("0x07"), types.StringToHash("0x06"), nil, f, nil, logger)
+	snaps.Update(types.StringToHash("0x08"), types.StringToHash("0x07"), nil, g, nil, logger)
+	snaps.Update(types.StringToHash("0x09"), types.StringToHash("0x08"), nil, h, nil, logger)
+
+	// Every value the fast iterator yields must agree with a direct lookup
+	// through the head snapshot.
+	it, _ := snaps.AccountIterator(types.StringToHash("0x09"), types.Hash{})
+	head := snaps.Snapshot(types.StringToHash("0x09"))
+	for it.Next() {
+		hash := it.Hash()
+		want, err := head.AccountRLP(hash)
+		if err != nil {
+			t.Fatalf("failed to retrieve expected account: %v", err)
+		}
+		if have := it.Account(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
+		}
+	}
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(types.StringToHash("0x09"), 2)
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x09"), types.Hash{})
+	for it.Next() {
+		hash := it.Hash()
+		want, err := head.AccountRLP(hash)
+		if err != nil {
+			t.Fatalf("failed to retrieve expected account: %v", err)
+		}
+		if have := it.Account(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
+		}
+	}
+	it.Release()
+}
+
+// TestStorageIteratorTraversalValues tests multi-layer storage iteration,
+// where we also expect the correct slot values to show up.
+func TestStorageIteratorTraversalValues(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	wrapStorage := func(storage map[types.Hash][]byte) map[types.Hash]map[types.Hash][]byte {
+		return map[types.Hash]map[types.Hash][]byte{
+			types.StringToHash("0xaa"): storage,
+		}
+	}
+	// Create a batch of storage sets to seed subsequent layers with
+	var (
+		a = make(map[types.Hash][]byte)
+		b = make(map[types.Hash][]byte)
+		c = make(map[types.Hash][]byte)
+		d = make(map[types.Hash][]byte)
+		e = make(map[types.Hash][]byte)
+		f = make(map[types.Hash][]byte)
+		g = make(map[types.Hash][]byte)
+		h = make(map[types.Hash][]byte)
+	)
+
+	for i := byte(2); i < 0xff; i++ {
+		a[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i))
+		if i > 20 && i%2 == 0 {
+			b[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 1, i))
+		}
+		if i%4 == 0 {
+			c[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 2, i))
+		}
+		if i%7 == 0 {
+			d[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 3, i))
+		}
+		if i%8 == 0 {
+			e[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 4, i))
+		}
+		// NOTE(review): the original condition `i > 50 || i < 85` was a
+		// tautology (true for every byte); the intended band is (50, 85).
+		if i > 50 && i < 85 {
+			f[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 5, i))
+		}
+		if i%64 == 0 {
+			g[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 6, i))
+		}
+		if i%128 == 0 {
+			h[types.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i))
+		}
+	}
+	// Assemble a stack of snapshots from the storage layers.
+	// BUG FIX: layer 0x07 previously re-used wrapStorage(e), leaving the
+	// prepared set `f` entirely unused; it now installs `f` as intended.
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil, randomAccountSet("0xaa"), wrapStorage(a), logger)
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil, randomAccountSet("0xaa"), wrapStorage(b), logger)
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil, randomAccountSet("0xaa"), wrapStorage(c), logger)
+	snaps.Update(types.StringToHash("0x05"), types.StringToHash("0x04"), nil, randomAccountSet("0xaa"), wrapStorage(d), logger)
+	snaps.Update(types.StringToHash("0x06"), types.StringToHash("0x05"), nil, randomAccountSet("0xaa"), wrapStorage(e), logger)
+	snaps.Update(types.StringToHash("0x07"), types.StringToHash("0x06"), nil, randomAccountSet("0xaa"), wrapStorage(f), logger)
+	snaps.Update(types.StringToHash("0x08"), types.StringToHash("0x07"), nil, randomAccountSet("0xaa"), wrapStorage(g), logger)
+	snaps.Update(types.StringToHash("0x09"), types.StringToHash("0x08"), nil, randomAccountSet("0xaa"), wrapStorage(h), logger)
+
+	// Every slot the fast iterator yields must agree with a direct lookup
+	// through the head snapshot.
+	it, _ := snaps.StorageIterator(types.StringToHash("0x09"), types.StringToHash("0xaa"), types.Hash{})
+	head := snaps.Snapshot(types.StringToHash("0x09"))
+	for it.Next() {
+		hash := it.Hash()
+		want, err := head.Storage(types.StringToHash("0xaa"), hash)
+		if err != nil {
+			t.Fatalf("failed to retrieve expected storage slot: %v", err)
+		}
+		if have := it.Slot(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: slot mismatch: have %x, want %x", hash, have, want)
+		}
+	}
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(types.StringToHash("0x09"), 2)
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x09"), types.StringToHash("0xaa"), types.Hash{})
+	for it.Next() {
+		hash := it.Hash()
+		want, err := head.Storage(types.StringToHash("0xaa"), hash)
+		if err != nil {
+			t.Fatalf("failed to retrieve expected slot: %v", err)
+		}
+		if have := it.Slot(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: slot mismatch: have %x, want %x", hash, have, want)
+		}
+	}
+	it.Release()
+}
+
+// This testcase is notorious, all layers contain the exact same 200 accounts.
+func TestAccountIteratorLargeTraversal(t *testing.T) {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[types.Hash][]byte {
+		accounts := make(map[types.Hash][]byte)
+		for i := 0; i < num; i++ {
+			h := types.Hash{}
+			binary.BigEndian.PutUint64(h[:], uint64(i+1))
+			accounts[h] = randomAccount()
+		}
+		return accounts
+	}
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// 127 diff layers: roots 0x02..0x80 (128 decimal), each rewriting the
+	// same 200 account keys.
+	for i := 1; i < 128; i++ {
+		snaps.Update(types.StringToHash(fmt.Sprintf("0x%02x", i+1)), types.StringToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil, logger)
+	}
+	// Iterate the entire stack and ensure everything is hit only once
+	head := snaps.Snapshot(types.StringToHash("0x80"))
+	verifyIterator(t, 200, head.(snapshot).AccountIterator(types.Hash{}), verifyNothing)
+	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator(), verifyAccount)
+
+	it, _ := snaps.AccountIterator(types.StringToHash("0x80"), types.Hash{})
+	verifyIterator(t, 200, it, verifyAccount)
+	it.Release()
+
+	// Test after persist some bottom-most layers into the disk,
+	// the functionalities still work.
+	limit := aggregatorMemoryLimit
+	defer func() {
+		aggregatorMemoryLimit = limit
+	}()
+	aggregatorMemoryLimit = 0 // Force pushing the bottom-most layer into disk
+	snaps.Cap(types.StringToHash("0x80"), 2)
+
+	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator(), verifyAccount)
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x80"), types.Hash{})
+	verifyIterator(t, 200, it, verifyAccount)
+	it.Release()
+}
+
+// TestAccountIteratorFlattening tests what happens when we
+// - have a live iterator on child C (parent C1 -> C2 .. CN)
+// - flattens C2 all the way into CN
+// - continues iterating
+func TestAccountIteratorFlattening(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Create a stack of diffs on top
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil,
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil, logger)
+
+	// Create an iterator and flatten the data from underneath it
+	it, _ := snaps.AccountIterator(types.StringToHash("0x04"), types.Hash{})
+	defer it.Release()
+
+	// The test only asserts that flattening under a live iterator does not
+	// error; continued iteration afterwards is deliberately left unchecked.
+	if err := snaps.Cap(types.StringToHash("0x04"), 1); err != nil {
+		t.Fatalf("failed to flatten snapshot stack: %v", err)
+	}
+	//verifyIterator(t, 7, it)
+}
+
+// TestAccountIteratorSeek checks that seeking positions the combo iterator at
+// the first key >= the seek hash (the seek position itself is included).
+func TestAccountIteratorSeek(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create a snapshot stack with some initial data
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil,
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil, logger)
+
+	// Account set is now
+	// 02: aa, ee, f0, ff
+	// 03: aa, bb, dd, ee, f0 (, f0), ff
+	// 04: aa, bb, cc, dd, ee, f0 (, f0), ff (, ff)
+	// Construct various iterators and ensure their traversal is correct
+	it, _ := snaps.AccountIterator(types.StringToHash("0x02"), types.StringToHash("0xdd"))
+	defer it.Release()
+	verifyIterator(t, 3, it, verifyAccount) // expected: ee, f0, ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x02"), types.StringToHash("0xaa"))
+	defer it.Release()
+	verifyIterator(t, 4, it, verifyAccount) // expected: aa, ee, f0, ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x02"), types.StringToHash("0xff"))
+	defer it.Release()
+	verifyIterator(t, 1, it, verifyAccount) // expected: ff
+
+	// "0xff1" hashes past every key, so the iterator is immediately exhausted.
+	it, _ = snaps.AccountIterator(types.StringToHash("0x02"), types.StringToHash("0xff1"))
+	defer it.Release()
+	verifyIterator(t, 0, it, verifyAccount) // expected: nothing
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.StringToHash("0xbb"))
+	defer it.Release()
+	verifyIterator(t, 6, it, verifyAccount) // expected: bb, cc, dd, ee, f0, ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.StringToHash("0xef"))
+	defer it.Release()
+	verifyIterator(t, 2, it, verifyAccount) // expected: f0, ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.StringToHash("0xf0"))
+	defer it.Release()
+	verifyIterator(t, 2, it, verifyAccount) // expected: f0, ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.StringToHash("0xff"))
+	defer it.Release()
+	verifyIterator(t, 1, it, verifyAccount) // expected: ff
+
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.StringToHash("0xff1"))
+	defer it.Release()
+	verifyIterator(t, 0, it, verifyAccount) // expected: nothing
+}
+
+// TestStorageIteratorSeek checks that storage iteration honors the inclusive
+// seek position within a single account's slot space.
+func TestStorageIteratorSeek(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create a snapshot stack with some initial data
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x03", "0x05"}}, nil), logger)
+
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x02", "0x05", "0x06"}}, nil), logger)
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x05", "0x08"}}, nil), logger)
+
+	// Account set is now
+	// 02: 01, 03, 05
+	// 03: 01, 02, 03, 05 (, 05), 06
+	// 04: 01(, 01), 02, 03, 05(, 05, 05), 06, 08
+	// Construct various iterators and ensure their traversal is correct
+	it, _ := snaps.StorageIterator(types.StringToHash("0x02"), types.StringToHash("0xaa"), types.StringToHash("0x01"))
+	defer it.Release()
+	verifyIterator(t, 3, it, verifyStorage) // expected: 01, 03, 05
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x02"), types.StringToHash("0xaa"), types.StringToHash("0x02"))
+	defer it.Release()
+	verifyIterator(t, 2, it, verifyStorage) // expected: 03, 05
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x02"), types.StringToHash("0xaa"), types.StringToHash("0x5"))
+	defer it.Release()
+	verifyIterator(t, 1, it, verifyStorage) // expected: 05
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x02"), types.StringToHash("0xaa"), types.StringToHash("0x6"))
+	defer it.Release()
+	verifyIterator(t, 0, it, verifyStorage) // expected: nothing
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.StringToHash("0x01"))
+	defer it.Release()
+	verifyIterator(t, 6, it, verifyStorage) // expected: 01, 02, 03, 05, 06, 08
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.StringToHash("0x05"))
+	defer it.Release()
+	verifyIterator(t, 3, it, verifyStorage) // expected: 05, 06, 08
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.StringToHash("0x08"))
+	defer it.Release()
+	verifyIterator(t, 1, it, verifyStorage) // expected: 08
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.StringToHash("0x09"))
+	defer it.Release()
+	verifyIterator(t, 0, it, verifyStorage) // expected: nothing
+}
+
+// TestAccountIteratorDeletions tests that the iterator behaves correct when there are
+// deleted accounts (where the Account() value is nil). The iterator
+// should not output any accounts or nil-values for those cases.
+func TestAccountIteratorDeletions(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"),
+		nil, randomAccountSet("0x11", "0x22", "0x33"), nil, logger)
+
+	// Destruct 0x22 in the middle layer without re-creating it.
+	deleted := types.StringToHash("0x22")
+	destructed := map[types.Hash]struct{}{
+		deleted: {},
+	}
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"),
+		destructed, randomAccountSet("0x11", "0x33"), nil, logger)
+
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"),
+		nil, randomAccountSet("0x33", "0x44", "0x55"), nil, logger)
+
+	// The output should be 11,33,44,55
+	it, _ := snaps.AccountIterator(types.StringToHash("0x04"), types.Hash{})
+	// Do a quick check
+	verifyIterator(t, 4, it, verifyAccount)
+	it.Release()
+
+	// And a more detailed verification that we indeed do not see '0x22'
+	it, _ = snaps.AccountIterator(types.StringToHash("0x04"), types.Hash{})
+	defer it.Release()
+	for it.Next() {
+		hash := it.Hash()
+		if it.Account() == nil {
+			t.Errorf("iterator returned nil-value for hash %x", hash)
+		}
+		if hash == deleted {
+			t.Errorf("expected deleted elem %x to not be returned by iterator", deleted)
+		}
+	}
+}
+
+// TestStorageIteratorDeletions checks that deleted (nil-valued) slots and
+// whole-account destructions never surface through the storage iterators,
+// including the destruct-then-recreate-in-the-same-layer case.
+func TestStorageIteratorDeletions(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	metrics := NilMetrics()
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb:      rawdb.NewMemoryDatabase(),
+		root:        types.StringToHash("0x01"),
+		cache:       fastcache.New(1024 * 500),
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	snaps := &Tree{
+		layers: map[types.Hash]snapshot{
+			base.root: base,
+		},
+		logger:      logger,
+		snapmetrics: metrics,
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x01", "0x03", "0x05"}}, nil), logger)
+
+	// Second layer adds 02,04,06 and deletes 01,03 (fourth arg = nil slots).
+	snaps.Update(types.StringToHash("0x03"), types.StringToHash("0x02"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x02", "0x04", "0x06"}}, [][]string{{"0x01", "0x03"}}), logger)
+
+	// The output should be 02,04,05,06
+	it, _ := snaps.StorageIterator(types.StringToHash("0x03"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 4, it, verifyStorage)
+	it.Release()
+
+	// The output should be 04,05,06
+	it, _ = snaps.StorageIterator(types.StringToHash("0x03"), types.StringToHash("0xaa"), types.StringToHash("0x03"))
+	verifyIterator(t, 3, it, verifyStorage)
+	it.Release()
+
+	// Destruct the whole storage
+	destructed := map[types.Hash]struct{}{
+		types.StringToHash("0xaa"): {},
+	}
+	snaps.Update(types.StringToHash("0x04"), types.StringToHash("0x03"), destructed, nil, nil, logger)
+
+	it, _ = snaps.StorageIterator(types.StringToHash("0x04"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 0, it, verifyStorage)
+	it.Release()
+
+	// Re-insert the slots of the same account
+	snaps.Update(types.StringToHash("0x05"), types.StringToHash("0x04"), nil,
+		randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x07", "0x08", "0x09"}}, nil), logger)
+
+	// The output should be 07,08,09
+	it, _ = snaps.StorageIterator(types.StringToHash("0x05"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 3, it, verifyStorage)
+	it.Release()
+
+	// Destruct the whole storage but re-create the account in the same layer
+	snaps.Update(types.StringToHash("0x06"), types.StringToHash("0x05"), destructed, randomAccountSet("0xaa"), randomStorageSet([]string{"0xaa"}, [][]string{{"0x11", "0x12"}}, nil), logger)
+	it, _ = snaps.StorageIterator(types.StringToHash("0x06"), types.StringToHash("0xaa"), types.Hash{})
+	verifyIterator(t, 2, it, verifyStorage) // The output should be 11,12
+	it.Release()
+
+	verifyIterator(t, 2, snaps.Snapshot(types.StringToHash("0x06")).(*diffLayer).newBinaryStorageIterator(types.StringToHash("0xaa")), verifyStorage)
+}
+
+// BenchmarkAccountIteratorTraversal is a bit a bit notorious -- all layers contain the
+// exact same 200 accounts. That means that we need to process 2000 items, but
+// only spit out 200 values eventually.
+//
+// The value-fetching benchmark is easy on the binary iterator, since it never has to reach
+// down at any depth for retrieving the values -- all are on the topmost layer
+//
+// BenchmarkAccountIteratorTraversal/binary_iterator_keys-6 2239 483674 ns/op
+// BenchmarkAccountIteratorTraversal/binary_iterator_values-6 2403 501810 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_keys-6 1923 677966 ns/op
+// BenchmarkAccountIteratorTraversal/fast_iterator_values-6 1741 649967 ns/op
+func BenchmarkAccountIteratorTraversal(b *testing.B) {
+ // Create a custom account factory to recreate the same addresses
+ makeAccounts := func(num int) map[types.Hash][]byte {
+ accounts := make(map[types.Hash][]byte)
+ for i := 0; i < num; i++ {
+ h := types.Hash{}
+ binary.BigEndian.PutUint64(h[:], uint64(i+1))
+ accounts[h] = randomAccount()
+ }
+ return accounts
+ }
+ logger := hclog.NewNullLogger()
+ metrics := NilMetrics()
+ // Build up a large stack of snapshots
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: types.StringToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ logger: logger,
+ snapmetrics: metrics,
+ }
+ snaps := &Tree{
+ layers: map[types.Hash]snapshot{
+ base.root: base,
+ },
+ logger: logger,
+ snapmetrics: metrics,
+ }
+ for i := 1; i <= 100; i++ {
+ snaps.Update(types.StringToHash(fmt.Sprintf("0x%02x", i+1)), types.StringToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(200), nil, logger)
+ }
+ // We call this once before the benchmark, so the creation of
+ // sorted accountlists are not included in the results.
+ head := snaps.Snapshot(types.StringToHash("0x65"))
+ head.(*diffLayer).newBinaryAccountIterator()
+
+ b.Run("binary iterator keys", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("binary iterator values", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ head.(*diffLayer).accountRLP(it.Hash(), 0)
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator keys", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(types.StringToHash("0x65"), types.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator values", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(types.StringToHash("0x65"), types.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ it.Account()
+ }
+ if exp := 200; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+}
+
+// BenchmarkAccountIteratorLargeBaselayer is a pretty realistic benchmark, where
+// the baselayer is a lot larger than the upper layer.
+//
+// This is heavy on the binary iterator, which in most cases will have to
+// call recursively 100 times for the majority of the values
+//
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(keys)-6 514 1971999 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/binary_iterator_(values)-6 61 18997492 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(keys)-6 10000 114385 ns/op
+// BenchmarkAccountIteratorLargeBaselayer/fast_iterator_(values)-6 4047 296823 ns/op
+func BenchmarkAccountIteratorLargeBaselayer(b *testing.B) {
+ // Create a custom account factory to recreate the same addresses
+ makeAccounts := func(num int) map[types.Hash][]byte {
+ accounts := make(map[types.Hash][]byte)
+ for i := 0; i < num; i++ {
+ h := types.Hash{}
+ binary.BigEndian.PutUint64(h[:], uint64(i+1))
+ accounts[h] = randomAccount()
+ }
+ return accounts
+ }
+ logger := hclog.NewNullLogger()
+ metrics := NilMetrics()
+ // Build up a large stack of snapshots
+ base := &diskLayer{
+ diskdb: rawdb.NewMemoryDatabase(),
+ root: types.StringToHash("0x01"),
+ cache: fastcache.New(1024 * 500),
+ logger: logger,
+ snapmetrics: metrics,
+ }
+ snaps := &Tree{
+ layers: map[types.Hash]snapshot{
+ base.root: base,
+ },
+ logger: logger,
+ snapmetrics: metrics,
+ }
+ snaps.Update(types.StringToHash("0x02"), types.StringToHash("0x01"), nil, makeAccounts(2000), nil, logger)
+ for i := 2; i <= 100; i++ {
+ snaps.Update(types.StringToHash(fmt.Sprintf("0x%02x", i+1)), types.StringToHash(fmt.Sprintf("0x%02x", i)), nil, makeAccounts(20), nil, logger)
+ }
+ // We call this once before the benchmark, so the creation of
+ // sorted accountlists are not included in the results.
+ head := snaps.Snapshot(types.StringToHash("0x65"))
+ head.(*diffLayer).newBinaryAccountIterator()
+
+ b.Run("binary iterator (keys)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("binary iterator (values)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ got := 0
+ it := head.(*diffLayer).newBinaryAccountIterator()
+ for it.Next() {
+ got++
+ v := it.Hash()
+ head.(*diffLayer).accountRLP(v, 0)
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator (keys)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(types.StringToHash("0x65"), types.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+ b.Run("fast iterator (values)", func(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ it, _ := snaps.AccountIterator(types.StringToHash("0x65"), types.Hash{})
+ defer it.Release()
+
+ got := 0
+ for it.Next() {
+ it.Account()
+ got++
+ }
+ if exp := 2000; got != exp {
+ b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
+ }
+ }
+ })
+}
diff --git a/state/snapshot/journal.go b/state/snapshot/journal.go
new file mode 100644
index 0000000000..b398e8e686
--- /dev/null
+++ b/state/snapshot/journal.go
@@ -0,0 +1,407 @@
+package snapshot
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
// journalVersion is the expected on-disk version of the snapshot journal;
// journals written with any other version are discarded on load (see
// iterateJournal).
const journalVersion uint64 = 0

// journalGenerator is a disk layer entry containing the generator progress marker.
type journalGenerator struct {
	// Indicator that whether the database was in progress of being wiped.
	// It's deprecated but keep it here for background compatibility.
	Wiping bool

	Done     bool   // Whether the generator finished creating the snapshot
	Marker   []byte // Generation progress marker; loadSnapshot reads its first 8 bytes as the resume origin
	Accounts uint64 // Number of accounts generated so far
	Slots    uint64 // Number of storage slots generated so far
	Storage  uint64 // Amount of storage generated so far (presumably bytes -- fed into types.StorageSize)
}
+
// journalDestruct is an account deletion entry in a diffLayer's disk journal.
type journalDestruct struct {
	Hash types.Hash // Hash of the destructed account
}

// journalAccount is an account entry in a diffLayer's disk journal.
type journalAccount struct {
	Hash types.Hash // Account hash
	Blob []byte     // Account payload; zero-length is reinterpreted as nil on load (see iterateJournal)
}

// journalStorage is an account's storage map in a diffLayer's disk journal.
type journalStorage struct {
	Hash types.Hash   // Owning account hash
	Keys []types.Hash // Storage slot keys
	Vals [][]byte     // Slot values, parallel to Keys; zero-length values decode as nil
}
+
// Journal writes the memory layer contents into a buffer to be stored in the
// database as the snapshot journal.
//
// Layers are journalled bottom-up: the parent is written first so that
// iterateJournal can rebuild the layer stack in the same order. For this
// layer the record layout is root, destruct set, account data and storage
// data, each RLP-encoded in that exact order. The returned hash is the root
// of the disk layer at the bottom of the stack.
func (dl *diffLayer) Journal(buffer *bytes.Buffer) (types.Hash, error) {
	// Journal the parent first
	base, err := dl.parent.Journal(buffer)
	if err != nil {
		return types.Hash{}, err
	}

	// Ensure the layer didn't get stale
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	if dl.Stale() {
		return types.Hash{}, ErrSnapshotStale
	}

	// Everything below was journalled, persist this layer too
	if err := rlp.Encode(buffer, dl.root); err != nil {
		return types.Hash{}, err
	}

	// Flatten the destruct set into a list of journal entries. Map
	// iteration order is random, but the loader rebuilds a set, so
	// ordering does not matter.
	destructs := make([]journalDestruct, 0, len(dl.destructSet))
	for hash := range dl.destructSet {
		destructs = append(destructs, journalDestruct{Hash: hash})
	}

	if err := rlp.Encode(buffer, destructs); err != nil {
		return types.Hash{}, err
	}

	// Flatten the per-account data the same way.
	accounts := make([]journalAccount, 0, len(dl.accountData))
	for hash, blob := range dl.accountData {
		accounts = append(accounts, journalAccount{Hash: hash, Blob: blob})
	}

	if err := rlp.Encode(buffer, accounts); err != nil {
		return types.Hash{}, err
	}

	// Flatten the two-level storage maps into parallel key/value slices
	// (RLP cannot encode maps directly).
	storage := make([]journalStorage, 0, len(dl.storageData))

	for hash, slots := range dl.storageData {
		keys := make([]types.Hash, 0, len(slots))
		vals := make([][]byte, 0, len(slots))

		for key, val := range slots {
			keys = append(keys, key)
			vals = append(vals, val)
		}

		storage = append(storage, journalStorage{Hash: hash, Keys: keys, Vals: vals})
	}

	if err := rlp.Encode(buffer, storage); err != nil {
		return types.Hash{}, err
	}

	dl.logger.Debug("Journalled diff layer", "root", dl.root, "parent", dl.parent.Root())

	return base, nil
}
+
// Journal terminates any in-progress snapshot generation, also implicitly pushing
// the progress into the database.
//
// The disk layer contributes no diff records, so buffer is deliberately left
// untouched here; only the generator progress is persisted (journalProgress).
// The returned root anchors the diff layers journalled on top of it.
func (dl *diskLayer) Journal(buffer *bytes.Buffer) (types.Hash, error) {
	// If the snapshot is currently being generated, abort it
	var stats *generatorStats

	if dl.genAbort != nil {
		// Hand the generator a channel to report its final stats on, then
		// block until it acknowledges the abort.
		abort := make(chan *generatorStats)
		dl.genAbort <- abort

		if stats = <-abort; stats != nil {
			stats.Log("Journalling in-progress snapshot", dl.root, dl.genMarker)
		}
	}

	// Ensure the layer didn't get stale
	dl.lock.RLock()
	defer dl.lock.RUnlock()

	if dl.stale {
		return types.Hash{}, ErrSnapshotStale
	}

	// Ensure the generator stats is written even if none was ran this cycle
	journalProgress(dl.diskdb, dl.genMarker, stats, dl.logger)

	dl.logger.Debug("Journalled disk layer", "root", dl.root)

	return dl.root, nil
}
+
// loadSnapshot loads a pre-existing state snapshot backed by a key-value store.
//
// It returns the topmost snapshot layer (the disk layer plus any journalled
// diff layers), a boolean reporting whether snapshotting is administratively
// disabled, and an error. If the journal head does not match root and
// recovery is false an error is returned so the caller can rebuild the
// snapshot; with recovery true the mismatch is tolerated. Unless noBuild is
// set, an interrupted generation run is resumed in the background.
func loadSnapshot(
	logger kvdb.Logger,
	metrics *Metrics,
	diskdb kvdb.KVBatchStorage,
	triedb *trie.Database,
	root types.Hash,
	cache int,
	recovery bool,
	noBuild bool,
) (snapshot, bool, error) {
	// If snapshotting is disabled (initial sync in progress), don't do anything,
	// wait for the chain to permit us to do something meaningful
	if rawdb.ReadSnapshotDisabled(diskdb) {
		return nil, true, nil
	}

	// Retrieve the block number and hash of the snapshot, failing if no snapshot
	// is present in the database (or crashed mid-update).
	baseRoot := rawdb.ReadSnapshotRoot(diskdb)
	if baseRoot == (types.Hash{}) {
		return nil, false, errors.New("missing or corrupted snapshot")
	}

	// `cache` is interpreted in megabytes.
	base := &diskLayer{
		diskdb:      diskdb,
		triedb:      triedb,
		cache:       fastcache.New(cache * 1024 * 1024),
		root:        baseRoot,
		logger:      logger,
		snapmetrics: metrics,
	}

	snapshot, generator, err := loadAndParseJournal(logger, diskdb, base)
	if err != nil {
		logger.Warn("Failed to load journal", "error", err)

		return nil, false, err
	}
	// Entire snapshot journal loaded, sanity check the head. If the loaded
	// snapshot is not matched with current state root, print a warning log
	// or discard the entire snapshot it's legacy snapshot.
	//
	// Possible scenario: Geth was crashed without persisting journal and then
	// restart, the head is rewound to the point with available state(trie)
	// which is below the snapshot. In this case the snapshot can be recovered
	// by re-executing blocks but right now it's unavailable.
	if head := snapshot.Root(); head != root {
		// If it's legacy snapshot, or it's new-format snapshot but
		// it's not in recovery mode, returns the error here for
		// rebuilding the entire snapshot forcibly.
		if !recovery {
			return nil, false, fmt.Errorf("head doesn't match snapshot: have %s, want %s", head, root)
		}
		// It's in snapshot recovery, the assumption is held that
		// the disk layer is always higher than chain head. It can
		// be eventually recovered when the chain head beyonds the
		// disk layer.
		logger.Warn("Snapshot is not continuous with chain", "snaproot", head, "chainroot", root)
	}
	// Load the disk layer status from the generator if it's not complete
	if !generator.Done {
		base.genMarker = generator.Marker
		if base.genMarker == nil {
			// Non-nil empty marker means "generation started, nothing done yet".
			base.genMarker = []byte{}
		}
	}

	// Everything loaded correctly, resume any suspended operations
	// if the background generation is allowed
	if !generator.Done && !noBuild {
		base.genPending = make(chan struct{})
		base.genAbort = make(chan chan *generatorStats)

		// The first 8 bytes of the marker encode the resume origin.
		var origin uint64

		if len(generator.Marker) >= 8 {
			origin = binary.BigEndian.Uint64(generator.Marker)
		}

		go base.generate(&generatorStats{
			origin:          origin,
			start:           time.Now(),
			accounts:        generator.Accounts,
			slots:           generator.Slots,
			storage:         types.StorageSize(generator.Storage),
			logger:          logger,
			generateMetrics: metrics,
		})
	}

	return snapshot, false, nil
}
+
+// loadAndParseJournal tries to parse the snapshot journal in latest format.
+func loadAndParseJournal(
+ logger kvdb.Logger,
+ db kvdb.KVBatchStorage,
+ base *diskLayer,
+) (snapshot, journalGenerator, error) {
+ // Retrieve the disk layer generator. It must exist, no matter the
+ // snapshot is fully generated or not. Otherwise the entire disk
+ // layer is invalid.
+ generatorBlob := rawdb.ReadSnapshotGenerator(db)
+ if len(generatorBlob) == 0 {
+ return nil, journalGenerator{}, errors.New("missing snapshot generator")
+ }
+
+ var generator journalGenerator
+ if err := rlp.DecodeBytes(generatorBlob, &generator); err != nil {
+ return nil, journalGenerator{}, fmt.Errorf("failed to decode snapshot generator: %w", err)
+ }
+
+ // Retrieve the diff layer journal. It's possible that the journal is
+ // not existent, e.g. the disk layer is generating while that the Geth
+ // crashes without persisting the diff journal.
+ // So if there is no journal, or the journal is invalid(e.g. the journal
+ // is not matched with disk layer; or the it's the legacy-format journal,
+ // etc.), we just discard all diffs and try to recover them later.
+ var current snapshot = base
+
+ err := iterateJournal(logger, db,
+ func(
+ parent types.Hash,
+ root types.Hash,
+ destructSet map[types.Hash]struct{},
+ accountData map[types.Hash][]byte,
+ storageData map[types.Hash]map[types.Hash][]byte,
+ ) error {
+ current = newDiffLayer(current, root, destructSet, accountData, storageData, logger, base.snapmetrics)
+
+ return nil
+ },
+ )
+ if err != nil {
+ return base, generator, nil
+ }
+
+ return current, generator, nil
+}
+
// journalCallback is a function which is invoked by iterateJournal, every
// time a difflayer is loaded from disk. The arguments mirror one journalled
// diff record: the parent layer's root, this layer's root, the destructed
// account set, the per-account payloads and the per-account storage slots.
// Returning a non-nil error aborts the iteration.
type journalCallback = func(
	parent types.Hash,
	root types.Hash,
	destructs map[types.Hash]struct{},
	accounts map[types.Hash][]byte,
	storage map[types.Hash]map[types.Hash][]byte,
) error
+
+// iterateJournal iterates through the journalled difflayers, loading them from
+// the database, and invoking the callback for each loaded layer.
+// The order is incremental; starting with the bottom-most difflayer, going towards
+// the most recent layer.
+// This method returns error either if there was some error reading from disk,
+// OR if the callback returns an error when invoked.
+func iterateJournal(logger kvdb.Logger, db kvdb.KVBatchStorage, callback journalCallback) error {
+ journal := rawdb.ReadSnapshotJournal(db)
+ if len(journal) == 0 {
+ logger.Warn("Loaded snapshot journal", "diffs", "missing")
+
+ return nil
+ }
+
+ r := rlp.NewStream(bytes.NewReader(journal), 0)
+
+ // Firstly, resolve the first element as the journal version
+ version, err := r.Uint64()
+ if err != nil {
+ logger.Warn("Failed to resolve the journal version", "error", err)
+
+ return errors.New("failed to resolve journal version")
+ }
+
+ if version != journalVersion {
+ logger.Warn("Discarded the snapshot journal with wrong version", "required", journalVersion, "got", version)
+
+ return errors.New("wrong journal version")
+ }
+
+ // Secondly, resolve the disk layer root, ensure it's continuous
+ // with disk layer. Note now we can ensure it's the snapshot journal
+ // correct version, so we expect everything can be resolved properly.
+ var parent types.Hash
+
+ if err := r.Decode(&parent); err != nil {
+ return errors.New("missing disk layer root")
+ }
+
+ if baseRoot := rawdb.ReadSnapshotRoot(db); baseRoot != parent {
+ logger.Warn("Loaded snapshot journal", "diskroot", baseRoot, "diffs", "unmatched")
+
+ return fmt.Errorf("mismatched disk and diff layers")
+ }
+
+ for {
+ var (
+ root types.Hash
+ destructs []journalDestruct
+ accounts []journalAccount
+ storage []journalStorage
+ destructSet = make(map[types.Hash]struct{})
+ accountData = make(map[types.Hash][]byte)
+ storageData = make(map[types.Hash]map[types.Hash][]byte)
+ )
+
+ // Read the next diff journal entry
+ if err := r.Decode(&root); err != nil {
+ // The first read may fail with EOF, marking the end of the journal
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+
+ return fmt.Errorf("load diff root: %w", err)
+ }
+
+ if err := r.Decode(&destructs); err != nil {
+ return fmt.Errorf("load diff destructs: %w", err)
+ }
+
+ if err := r.Decode(&accounts); err != nil {
+ return fmt.Errorf("load diff accounts: %w", err)
+ }
+
+ if err := r.Decode(&storage); err != nil {
+ return fmt.Errorf("load diff storage: %w", err)
+ }
+
+ for _, entry := range destructs {
+ destructSet[entry.Hash] = struct{}{}
+ }
+
+ for _, entry := range accounts {
+ if len(entry.Blob) > 0 { // RLP loses nil-ness, but `[]byte{}` is not a valid item, so reinterpret that
+ accountData[entry.Hash] = entry.Blob
+ } else {
+ accountData[entry.Hash] = nil
+ }
+ }
+
+ for _, entry := range storage {
+ slots := make(map[types.Hash][]byte)
+
+ for i, key := range entry.Keys {
+ if len(entry.Vals[i]) > 0 { // RLP loses nil-ness, but `[]byte{}` is not a valid item, so reinterpret that
+ slots[key] = entry.Vals[i]
+ } else {
+ slots[key] = nil
+ }
+ }
+
+ storageData[entry.Hash] = slots
+ }
+
+ if err := callback(parent, root, destructSet, accountData, storageData); err != nil {
+ return err
+ }
+
+ parent = root
+ }
+}
diff --git a/state/snapshot/metrics.go b/state/snapshot/metrics.go
new file mode 100644
index 0000000000..985a435f98
--- /dev/null
+++ b/state/snapshot/metrics.go
@@ -0,0 +1,330 @@
+package snapshot
+
+import (
+ "strings"
+
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/prometheus/client_golang/prometheus"
+)
+
const (
	// _subsystemID is the prometheus subsystem label shared by every
	// collector in this package.
	_subsystemID = "snapshot"
)

// generateMetricContext bundles the per-run duration trackers used while
// generating a snapshot; one is created per run via (*generateMetrics).Context.
// generateSeconds is tracked in seconds, everything else in nanoseconds.
type generateMetricContext struct {
	generateSeconds metrics.DurationContext
	accountProve    metrics.DurationContext
	accountTrieRead metrics.DurationContext
	accountSnapRead metrics.DurationContext
	accountWrite    metrics.DurationContext
	storageProve    metrics.DurationContext
	storageTrieRead metrics.DurationContext
	storageSnapRead metrics.DurationContext
	storageWrite    metrics.DurationContext
	storageClean    metrics.DurationContext
}

// Start begins timing the overall generation run.
func (ctx *generateMetricContext) Start() {
	ctx.generateSeconds.Start()
}
+
// generateMetrics groups the prometheus collectors that track snapshot
// generation progress: per-category account/storage counters, range proof
// outcomes, elapsed/estimated time, and stage latency histograms.
type generateMetrics struct {
	generatedAccountCount     prometheus.Counter
	recoveredAccountCount     prometheus.Counter
	wipedAccountCount         prometheus.Counter
	missallAccountCount       prometheus.Counter
	generatedStorageCount     prometheus.Counter
	recoveredStorageCount     prometheus.Counter
	wipedStorageCount         prometheus.Counter
	missallStorageCount       prometheus.Counter
	danglingStorageCount      prometheus.Counter
	successfulRangeProofCount prometheus.Counter
	failedRangeProofCount     prometheus.Counter
	usedSeconds               prometheus.Gauge
	estimateSeconds           prometheus.Gauge

	// accountProveNanoseconds measures time spent on the account proving
	accountProveNanoseconds prometheus.Histogram
	// accountTrieReadNanoseconds measures time spent on the account trie iteration
	accountTrieReadNanoseconds prometheus.Histogram
	// accountSnapReadNanoseconds measures time spent on the snapshot account iteration
	accountSnapReadNanoseconds prometheus.Histogram
	// accountWriteNanoseconds measures time spent on writing/updating/deleting accounts
	accountWriteNanoseconds prometheus.Histogram
	// storageProveNanoseconds measures time spent on storage proving
	storageProveNanoseconds prometheus.Histogram
	// storageTrieReadNanoseconds measures time spent on the storage trie iteration
	storageTrieReadNanoseconds prometheus.Histogram
	// storageSnapReadNanoseconds measures time spent on the snapshot storage iteration
	storageSnapReadNanoseconds prometheus.Histogram
	// storageWriteNanoseconds measures time spent on writing/updating storages
	storageWriteNanoseconds prometheus.Histogram
	// storageCleanNanoseconds measures time spent on deleting storages
	storageCleanNanoseconds prometheus.Histogram
}
+
+func newGenerateMetrics(namespace string, constLabels prometheus.Labels) *generateMetrics {
+ var (
+ generatedAccountCount = newCounter(namespace, "generate_generated_account_count", constLabels)
+ recoveredAccountCount = newCounter(namespace, "generate_recovered_account_count", constLabels)
+ wipedAccountCount = newCounter(namespace, "generate_wiped_account_count", constLabels)
+ missallAccountCount = newCounter(namespace, "generate_missall_account_count", constLabels)
+ generatedStorageCount = newCounter(namespace, "generate_generated_storage_count", constLabels)
+ recoveredStorageCount = newCounter(namespace, "generate_recovered_storage_count", constLabels)
+ wipedStorageCount = newCounter(namespace, "generate_wiped_storage_count", constLabels)
+ missallStorageCount = newCounter(namespace, "generate_missall_storage_count", constLabels)
+ danglingStorageCount = newCounter(namespace, "generate_dangling_storage_size", constLabels)
+ successfulRangeProofCount = newCounter(namespace, "generate_successful_range_proof_count", constLabels)
+ failedRangeProofCount = newCounter(namespace, "generate_failed_range_proof_count", constLabels)
+ usedSeconds = newGauge(namespace, "generate_used_seconds", constLabels)
+ estimateSeconds = newGauge(namespace, "generate_estimate_seconds", constLabels)
+ // all nanoseconds metrics
+ accountProveNanoseconds = newHistogram(namespace, "generate_account_prove_nanoseconds", constLabels)
+ accountTrieReadNanoSeconds = newHistogram(namespace, "generate_account_trie_read_nanoseconds", constLabels)
+ accountSnapReadNanoseconds = newHistogram(namespace, "generate_account_snap_read_nanoseconds", constLabels)
+ accountWriteNanoseconds = newHistogram(namespace, "generate_account_write_nanoseconds", constLabels)
+ storageProveNanoseconds = newHistogram(namespace, "generate_storage_prove_nanoseconds", constLabels)
+ storageTrieReadNanoseconds = newHistogram(namespace, "generate_storage_trie_read_nanoseconds", constLabels)
+ storageSnapReadNanoseconds = newHistogram(namespace, "generate_storage_snap_read_nanoseconds", constLabels)
+ storageWriteNanoseconds = newHistogram(namespace, "generate_storage_write_nanoseconds", constLabels)
+ storageCleanNanoseconds = newHistogram(namespace, "generate_storage_clean_nanoseconds", constLabels)
+ )
+
+ prometheus.MustRegister(generatedAccountCount)
+ prometheus.MustRegister(recoveredAccountCount)
+ prometheus.MustRegister(wipedAccountCount)
+ prometheus.MustRegister(missallAccountCount)
+ prometheus.MustRegister(generatedStorageCount)
+ prometheus.MustRegister(recoveredStorageCount)
+ prometheus.MustRegister(wipedStorageCount)
+ prometheus.MustRegister(missallStorageCount)
+ prometheus.MustRegister(danglingStorageCount)
+ prometheus.MustRegister(successfulRangeProofCount)
+ prometheus.MustRegister(failedRangeProofCount)
+ prometheus.MustRegister(usedSeconds)
+ prometheus.MustRegister(estimateSeconds)
+ prometheus.MustRegister(accountProveNanoseconds)
+ prometheus.MustRegister(accountTrieReadNanoSeconds)
+ prometheus.MustRegister(accountSnapReadNanoseconds)
+ prometheus.MustRegister(accountWriteNanoseconds)
+ prometheus.MustRegister(storageProveNanoseconds)
+ prometheus.MustRegister(storageTrieReadNanoseconds)
+ prometheus.MustRegister(storageSnapReadNanoseconds)
+ prometheus.MustRegister(storageWriteNanoseconds)
+ prometheus.MustRegister(storageCleanNanoseconds)
+
+ return &generateMetrics{
+ generatedAccountCount: generatedAccountCount,
+ recoveredAccountCount: recoveredAccountCount,
+ wipedAccountCount: wipedAccountCount,
+ missallAccountCount: missallAccountCount,
+ generatedStorageCount: generatedStorageCount,
+ recoveredStorageCount: recoveredStorageCount,
+ wipedStorageCount: wipedStorageCount,
+ missallStorageCount: missallStorageCount,
+ danglingStorageCount: danglingStorageCount,
+ successfulRangeProofCount: successfulRangeProofCount,
+ failedRangeProofCount: failedRangeProofCount,
+ usedSeconds: usedSeconds,
+ estimateSeconds: estimateSeconds,
+ accountProveNanoseconds: accountProveNanoseconds,
+ accountTrieReadNanoseconds: accountTrieReadNanoSeconds,
+ accountSnapReadNanoseconds: accountSnapReadNanoseconds,
+ accountWriteNanoseconds: accountWriteNanoseconds,
+ storageProveNanoseconds: storageProveNanoseconds,
+ storageTrieReadNanoseconds: storageTrieReadNanoseconds,
+ storageSnapReadNanoseconds: storageSnapReadNanoseconds,
+ storageWriteNanoseconds: storageWriteNanoseconds,
+ storageCleanNanoseconds: storageCleanNanoseconds,
+ }
+}
+
+func nilGenerateMetrics() *generateMetrics {
+ return &generateMetrics{}
+}
+
// Context creates a fresh set of duration trackers for one generation run.
// The overall timer counts in seconds while every stage timer counts in
// nanoseconds, matching the units of the collectors they will feed.
func (m *generateMetrics) Context() *generateMetricContext {
	return &generateMetricContext{
		generateSeconds: metrics.NewDurationContextWithUnit(metrics.DurationSecond),
		accountProve:    metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		accountTrieRead: metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		accountSnapRead: metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		accountWrite:    metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		storageProve:    metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		storageTrieRead: metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		storageSnapRead: metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		storageWrite:    metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
		storageClean:    metrics.NewDurationContextWithUnit(metrics.DurationNanosecond),
	}
}
+
// Metrics holds every prometheus collector used by the snapshot layer:
// generation progress (embedded generateMetrics), clean/dirty cache
// statistics, flush volumes and bloom filter effectiveness.
type Metrics struct {
	*generateMetrics

	// clean account cache statistics
	cleanAccountHitCount  prometheus.Counter
	cleanAccountMissCount prometheus.Counter
	cleanAccountInexCount prometheus.Counter
	cleanAccountReadSize  prometheus.Histogram
	cleanAccountWriteSize prometheus.Histogram

	// clean storage cache statistics
	cleanStorageHitCount  prometheus.Counter
	cleanStorageMissCount prometheus.Counter
	cleanStorageInexCount prometheus.Counter
	cleanStorageReadSize  prometheus.Histogram
	cleanStorageWriteSize prometheus.Histogram

	// dirty account cache statistics
	dirtyAccountHitCount  prometheus.Counter
	dirtyAccountMissCount prometheus.Counter
	dirtyAccountInexCount prometheus.Counter
	dirtyAccountReadSize  prometheus.Histogram
	dirtyAccountWriteSize prometheus.Histogram

	// dirty storage cache statistics
	dirtyStorageHitCount  prometheus.Counter
	dirtyStorageMissCount prometheus.Counter
	dirtyStorageInexCount prometheus.Counter
	dirtyStorageReadSize  prometheus.Histogram
	dirtyStorageWriteSize prometheus.Histogram

	// layer depth at which dirty hits were found
	dirtyAccountHitDepth prometheus.Histogram
	dirtyStorageHitDepth prometheus.Histogram

	// flush volumes
	flushAccountItemCount prometheus.Counter
	flushAccountSize      prometheus.Histogram
	flushStorageItemCount prometheus.Counter
	flushStorageSize      prometheus.Histogram

	// bloom filter indexing cost and health
	bloomIndexNanoseconds prometheus.Histogram
	bloomErrorCount       prometheus.Gauge

	// bloom filter effectiveness for accounts
	bloomAccountTrueHitCount  prometheus.Counter
	bloomAccountFalseHitCount prometheus.Counter
	bloomAccountMissCount     prometheus.Counter

	// bloom filter effectiveness for storage
	bloomStorageTrueHitCount  prometheus.Counter
	bloomStorageFalseHitCount prometheus.Counter
	bloomStorageMissCount     prometheus.Counter
}
+
// GetPrometheusMetrics builds and registers the snapshot metrics instance.
// constLabelsWithValues are alternating label name/value pairs. Registration
// uses MustRegister, so calling this twice for the same namespace and labels
// panics.
func GetPrometheusMetrics(namespace string, constLabelsWithValues ...string) *Metrics {
	constLabels := metrics.ParseLables(constLabelsWithValues...)

	m := &Metrics{
		cleanAccountHitCount:      newCounter(namespace, "clean_account_hit_count", constLabels),
		cleanAccountMissCount:     newCounter(namespace, "clean_account_miss_count", constLabels),
		cleanAccountInexCount:     newCounter(namespace, "clean_account_inex_count", constLabels),
		cleanAccountReadSize:      newHistogram(namespace, "clean_account_read_size", constLabels),
		cleanAccountWriteSize:     newHistogram(namespace, "clean_account_write_size", constLabels),
		cleanStorageHitCount:      newCounter(namespace, "clean_storage_hit_count", constLabels),
		cleanStorageMissCount:     newCounter(namespace, "clean_storage_miss_count", constLabels),
		cleanStorageInexCount:     newCounter(namespace, "clean_storage_inex_count", constLabels),
		cleanStorageReadSize:      newHistogram(namespace, "clean_storage_read_size", constLabels),
		cleanStorageWriteSize:     newHistogram(namespace, "clean_storage_write_size", constLabels),
		dirtyAccountHitCount:      newCounter(namespace, "dirty_account_hit_count", constLabels),
		dirtyAccountMissCount:     newCounter(namespace, "dirty_account_miss_count", constLabels),
		dirtyAccountInexCount:     newCounter(namespace, "dirty_account_inex_count", constLabels),
		dirtyAccountReadSize:      newHistogram(namespace, "dirty_account_read_size", constLabels),
		dirtyAccountWriteSize:     newHistogram(namespace, "dirty_account_write_size", constLabels),
		dirtyStorageHitCount:      newCounter(namespace, "dirty_storage_hit_count", constLabels),
		dirtyStorageMissCount:     newCounter(namespace, "dirty_storage_miss_count", constLabels),
		dirtyStorageInexCount:     newCounter(namespace, "dirty_storage_inex_count", constLabels),
		dirtyStorageReadSize:      newHistogram(namespace, "dirty_storage_read_size", constLabels),
		dirtyStorageWriteSize:     newHistogram(namespace, "dirty_storage_write_size", constLabels),
		dirtyAccountHitDepth:      newHistogram(namespace, "dirty_account_hit_depth", constLabels),
		dirtyStorageHitDepth:      newHistogram(namespace, "dirty_storage_hit_depth", constLabels),
		flushAccountItemCount:     newCounter(namespace, "flush_account_item_count", constLabels),
		flushAccountSize:          newHistogram(namespace, "flush_account_size", constLabels),
		flushStorageItemCount:     newCounter(namespace, "flush_storage_item_count", constLabels),
		flushStorageSize:          newHistogram(namespace, "flush_storage_size", constLabels),
		bloomIndexNanoseconds:     newHistogram(namespace, "bloom_index_nanoseconds", constLabels),
		bloomErrorCount:           newGauge(namespace, "bloom_error_count", constLabels),
		bloomAccountTrueHitCount:  newCounter(namespace, "bloom_account_true_hit_count", constLabels),
		bloomAccountFalseHitCount: newCounter(namespace, "bloom_account_false_hit_count", constLabels),
		bloomAccountMissCount:     newCounter(namespace, "bloom_account_miss_count", constLabels),
		bloomStorageTrueHitCount:  newCounter(namespace, "bloom_storage_true_hit_count", constLabels),
		bloomStorageFalseHitCount: newCounter(namespace, "bloom_storage_false_hit_count", constLabels),
		bloomStorageMissCount:     newCounter(namespace, "bloom_storage_miss_count", constLabels),
	}

	// The generation collectors register themselves inside newGenerateMetrics.
	m.generateMetrics = newGenerateMetrics(namespace, constLabels)

	prometheus.MustRegister(
		m.cleanAccountHitCount,
		m.cleanAccountMissCount,
		m.cleanAccountInexCount,
		m.cleanAccountReadSize,
		m.cleanAccountWriteSize,
		m.cleanStorageHitCount,
		m.cleanStorageMissCount,
		m.cleanStorageInexCount,
		m.cleanStorageReadSize,
		m.cleanStorageWriteSize,
		m.dirtyAccountHitCount,
		m.dirtyAccountMissCount,
		m.dirtyAccountInexCount,
		m.dirtyAccountReadSize,
		m.dirtyAccountWriteSize,
		m.dirtyStorageHitCount,
		m.dirtyStorageMissCount,
		m.dirtyStorageInexCount,
		m.dirtyStorageReadSize,
		m.dirtyStorageWriteSize,
		m.dirtyAccountHitDepth,
		m.dirtyStorageHitDepth,
		m.flushAccountItemCount,
		m.flushAccountSize,
		m.flushStorageItemCount,
		m.flushStorageSize,
		m.bloomIndexNanoseconds,
		m.bloomErrorCount,
		m.bloomAccountTrueHitCount,
		m.bloomAccountFalseHitCount,
		m.bloomAccountMissCount,
		m.bloomStorageTrueHitCount,
		m.bloomStorageFalseHitCount,
		m.bloomStorageMissCount,
	)

	return m
}
+
// metricName2Help derives a human-readable help string from a snake_case
// metric name by turning every underscore into a space.
func metricName2Help(name string) string {
	fields := strings.Split(name, "_")

	return strings.Join(fields, " ")
}
+
// newGauge builds an unregistered prometheus gauge under the snapshot
// subsystem, deriving its help text from the metric name.
func newGauge(namespace, name string, constLabels prometheus.Labels) prometheus.Gauge {
	return prometheus.NewGauge(prometheus.GaugeOpts{
		Namespace:   namespace,
		Subsystem:   _subsystemID,
		Name:        name,
		Help:        metricName2Help(name),
		ConstLabels: constLabels,
	})
}
+
// newCounter builds an unregistered prometheus counter under the snapshot
// subsystem, deriving its help text from the metric name.
func newCounter(namespace, name string, constLabels prometheus.Labels) prometheus.Counter {
	return prometheus.NewCounter(prometheus.CounterOpts{
		Namespace:   namespace,
		Subsystem:   _subsystemID,
		Name:        name,
		Help:        metricName2Help(name),
		ConstLabels: constLabels,
	})
}
+
// newHistogram builds an unregistered prometheus histogram (default buckets)
// under the snapshot subsystem, deriving its help text from the metric name.
func newHistogram(namespace, name string, constLabels prometheus.Labels) prometheus.Histogram {
	return prometheus.NewHistogram(prometheus.HistogramOpts{
		Namespace:   namespace,
		Subsystem:   _subsystemID,
		Name:        name,
		Help:        metricName2Help(name),
		ConstLabels: constLabels,
	})
}
+
+// NilMetrics will return the non operational snapshot metrics
+func NilMetrics() *Metrics {
+ return &Metrics{
+ generateMetrics: nilGenerateMetrics(),
+ }
+}
diff --git a/state/snapshot/snapshot.go b/state/snapshot/snapshot.go
new file mode 100644
index 0000000000..df946f6b7f
--- /dev/null
+++ b/state/snapshot/snapshot.go
@@ -0,0 +1,907 @@
+package snapshot
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "os"
+ "sync"
+ "sync/atomic"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/metrics"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/trie"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Sentinel errors surfaced by snapshot accessors and Tree operations;
+// callers should compare with errors.Is.
+var (
+ // ErrSnapshotStale is returned from data accessors if the underlying snapshot
+ // layer had been invalidated due to the chain progressing forward far enough
+ // to not maintain the layer's original state.
+ ErrSnapshotStale = errors.New("snapshot stale")
+
+ // ErrNotCoveredYet is returned from data accessors if the underlying snapshot
+ // is being generated currently and the requested data item is not yet in the
+ // range of accounts covered.
+ ErrNotCoveredYet = errors.New("not covered yet")
+
+ // ErrNotConstructed is returned if the callers want to iterate the snapshot
+ // while the generation is not finished yet.
+ ErrNotConstructed = errors.New("snapshot is not constructed")
+
+ // errSnapshotCycle is returned if a snapshot is attempted to be inserted
+ // that forms a cycle in the snapshot tree.
+ errSnapshotCycle = errors.New("snapshot cycle")
+)
+
+// Snapshot represents the functionality supported by a snapshot storage layer.
+// It is the read-only public view implemented by both the persistent disk
+// layer and the in-memory diff layers.
+type Snapshot interface {
+ // Root returns the root hash for which this snapshot was made.
+ Root() types.Hash
+
+ // Account directly retrieves the account associated with a particular hash in
+ // the snapshot slim data format.
+ Account(hash types.Hash) (*stypes.Account, error)
+
+ // AccountRLP directly retrieves the account RLP associated with a particular
+ // hash in the snapshot slim data format.
+ AccountRLP(hash types.Hash) ([]byte, error)
+
+ // Storage directly retrieves the storage data associated with a particular hash,
+ // within a particular account.
+ Storage(accountHash, storageHash types.Hash) ([]byte, error)
+}
+
+// snapshot is the internal version of the snapshot data layer that supports some
+// additional methods compared to the public API. Tree code type-asserts public
+// Snapshot values to this interface when it needs to mutate or journal a layer.
+type snapshot interface {
+ Snapshot
+
+ // Parent returns the subsequent layer of a snapshot, or nil if the base was
+ // reached.
+ //
+ // Note, the method is an internal helper to avoid type switching between the
+ // disk and diff layers. There is no locking involved.
+ Parent() snapshot
+
+ // Update creates a new layer on top of the existing snapshot diff tree with
+ // the specified data items.
+ //
+ // Note, the maps are retained by the method to avoid copying everything.
+ Update(
+ blockRoot types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ logger kvdb.Logger,
+ ) *diffLayer
+
+ // Journal commits an entire diff hierarchy to disk into a single journal entry.
+ // This is meant to be used during shutdown to persist the snapshot without
+ // flattening everything down (bad for reorgs).
+ Journal(buffer *bytes.Buffer) (types.Hash, error)
+
+ // Stale return whether this layer has become stale (was flattened across) or
+ // if it's still live.
+ Stale() bool
+
+ // AccountIterator creates an account iterator over an arbitrary layer.
+ AccountIterator(seek types.Hash) AccountIterator
+
+ // StorageIterator creates a storage iterator over an arbitrary layer.
+ StorageIterator(account types.Hash, seek types.Hash) (StorageIterator, bool)
+}
+
+// Config includes the configurations for snapshots.
+// The zero value enables building (NoBuild=false) synchronously
+// (AsyncBuild=false) with no read cache.
+type Config struct {
+ CacheSize int // Megabytes permitted to use for read caches
+ Recovery bool // Indicator that the snapshots is in the recovery mode
+ NoBuild bool // Indicator that the snapshots generation is disallowed
+ AsyncBuild bool // The snapshot generation is allowed to be constructed asynchronously
+}
+
+// Tree is an Ethereum state snapshot tree. It consists of one persistent base
+// layer backed by a key-value store, on top of which arbitrarily many in-memory
+// diff layers are topped. The memory diffs can form a tree with branching, but
+// the disk layer is singleton and common to all. If a reorg goes deeper than the
+// disk layer, everything needs to be deleted.
+//
+// The goal of a state snapshot is twofold: to allow direct access to account and
+// storage data to avoid expensive multi-level trie lookups; and to allow sorted,
+// cheap iteration of the account/storage tries for sync aid.
+type Tree struct {
+ config Config // Snapshots configurations
+ diskdb kvdb.KVBatchStorage // Persistent database to store the snapshot
+ triedb *trie.Database // In-memory cache to access the trie through
+ layers map[types.Hash]snapshot // Collection of all known layers
+ lock sync.RWMutex // Guards layers and the parent links between them
+
+ logger kvdb.Logger
+ snapmetrics *Metrics
+
+ // Test hooks
+ onFlatten func() // Hook invoked when the bottom most diff layers are flattened
+}
+
+// New attempts to load an already existing snapshot from a persistent key-value
+// store (with a number of memory layers from a journal), ensuring that the head
+// of the snapshot matches the expected one.
+//
+// If the snapshot is missing or the disk layer is broken, the snapshot will be
+// reconstructed using both the existing data and the state trie.
+// The repair happens on a background thread.
+//
+// If the memory layers in the journal do not match the disk layer (e.g. there is
+// a gap) or the journal is missing, there are two repair cases:
+//
+// - if the 'recovery' parameter is true, memory diff-layers and the disk-layer
+// will all be kept. This case happens when the snapshot is 'ahead' of the
+// state trie.
+// - otherwise, the entire snapshot is considered invalid and will be recreated on
+// a background thread.
+func New(
+ config Config,
+ diskdb kvdb.KVBatchStorage,
+ triedb *trie.Database,
+ root types.Hash,
+ logger kvdb.Logger,
+ snapmetrics *Metrics,
+) (*Tree, error) {
+ // Create a new, empty snapshot tree
+ snap := &Tree{
+ config: config,
+ diskdb: diskdb,
+ triedb: triedb,
+ layers: make(map[types.Hash]snapshot),
+ logger: logger,
+ snapmetrics: snapmetrics,
+ }
+
+ // Attempt to load a previously persisted snapshot and rebuild one if failed.
+ // Note: err is deliberately inspected AFTER the disabled flag below —
+ // a disabled snapshot returns an empty (but valid) tree regardless of err.
+ head, disabled, err := loadSnapshot(logger, snapmetrics, diskdb, triedb,
+ root, config.CacheSize, config.Recovery, config.NoBuild)
+ if disabled {
+ snap.logger.Warn("Snapshot maintenance disabled (syncing)")
+
+ return snap, nil
+ }
+
+ // Create the building waiter iff the background generation is allowed.
+ // The deferred wait makes synchronous builds block until generation is done.
+ if !config.NoBuild && !config.AsyncBuild {
+ defer snap.waitBuild()
+ }
+
+ if err != nil {
+ snap.logger.Warn("Failed to load snapshot", "err", err)
+
+ if !config.NoBuild {
+ snap.Rebuild(root)
+
+ return snap, nil
+ }
+
+ return nil, err // Bail out the error, don't rebuild automatically.
+ }
+
+ // Existing snapshot loaded, seed all the layers from head down to the disk layer
+ for head != nil {
+ snap.layers[head.Root()] = head
+ head = head.Parent()
+ }
+
+ return snap, nil
+}
+
+// waitBuild blocks until the snapshot finishes rebuilding. This method is meant
+// to be used by tests to ensure we're testing what we believe we are.
+func (t *Tree) waitBuild() {
+ // Find the rebuild termination channel; there is at most one disk layer in
+ // the tree, so the first one found is the one being generated.
+ var done chan struct{}
+
+ t.lock.RLock()
+ for _, layer := range t.layers {
+ if layer, ok := layer.(*diskLayer); ok {
+ done = layer.genPending
+
+ break
+ }
+ }
+ t.lock.RUnlock()
+
+ // Wait until the snapshot is generated
+ if done != nil {
+ <-done
+ }
+}
+
+// Disable interrupts any pending snapshot generator, deletes all the snapshot
+// layers in memory and marks snapshots disabled globally. In order to resume
+// the snapshot functionality, the caller must invoke Rebuild.
+func (t *Tree) Disable() {
+ // Interrupt any live snapshot layers
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ for _, layer := range t.layers {
+ switch layer := layer.(type) {
+ case *diskLayer:
+ // If the base layer is generating, abort it
+ if layer.genAbort != nil {
+ abort := make(chan *generatorStats)
+ layer.genAbort <- abort
+ <-abort
+ }
+ // Layer should be inactive now, mark it as stale
+ layer.lock.Lock()
+ layer.stale = true
+ layer.lock.Unlock()
+ case *diffLayer:
+ // If the layer is a simple diff, simply mark as stale
+ layer.lock.Lock()
+ atomic.StoreUint32(&layer.stale, 1)
+ layer.lock.Unlock()
+ default:
+ panic(fmt.Sprintf("unknown layer type: %T", layer))
+ }
+ }
+
+ t.layers = map[types.Hash]snapshot{}
+
+ // Delete all snapshot liveness information from the database
+ batch := t.diskdb.NewBatch()
+
+ // delete snapshot in batch
+ rawdb.WriteSnapshotDisabled(batch)
+ rawdb.DeleteSnapshotRoot(batch)
+ rawdb.DeleteSnapshotJournal(batch)
+ rawdb.DeleteSnapshotGenerator(batch)
+ rawdb.DeleteSnapshotRecoveryNumber(batch)
+
+ // Note, we don't delete the sync progress.
+ // A failure to persist the disabled markers is unrecoverable: the on-disk
+ // snapshot would be inconsistent, so the process exits hard.
+ if err := batch.Write(); err != nil {
+ t.logger.Error("Failed to disable snapshots", "err", err)
+ os.Exit(1)
+ }
+}
+
+// Snapshot retrieves a snapshot belonging to the given block root, or nil if no
+// snapshot is maintained for that block. Safe for concurrent use.
+func (t *Tree) Snapshot(blockRoot types.Hash) Snapshot {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ return t.layers[blockRoot]
+}
+
+// Snapshots returns all visited layers from the topmost layer with specific
+// root and traverses downward. The layer amount is limited by the given number.
+// If nodisk is set, then disk layer is excluded. A negative limit returns the
+// whole chain of layers (the counter never reaches zero).
+func (t *Tree) Snapshots(root types.Hash, limits int, nodisk bool) []Snapshot {
+ t.lock.RLock()
+ defer t.lock.RUnlock()
+
+ if limits == 0 {
+ return nil
+ }
+
+ layer := t.layers[root]
+ if layer == nil {
+ return nil
+ }
+
+ var ret []Snapshot
+
+ for {
+ // Stop before appending the disk layer when it is excluded
+ if _, isdisk := layer.(*diskLayer); isdisk && nodisk {
+ break
+ }
+
+ ret = append(ret, layer)
+
+ limits -= 1
+ if limits == 0 {
+ break
+ }
+
+ parent := layer.Parent()
+ if parent == nil {
+ break
+ }
+
+ layer = parent
+ }
+
+ return ret
+}
+
+// Update adds a new snapshot into the tree, if that can be linked to an existing
+// old parent. It is disallowed to insert a disk layer (the origin of all).
+func (t *Tree) Update(
+ blockRoot types.Hash,
+ parentRoot types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ logger kvdb.Logger,
+) error {
+ // Reject noop updates to avoid self-loops in the snapshot tree. This is a
+ // special case that can only happen for Clique networks where empty blocks
+ // don't modify the state (0 block subsidy).
+ //
+ // Although we could silently ignore this internally, it should be the caller's
+ // responsibility to avoid even attempting to insert such a snapshot.
+ if blockRoot == parentRoot {
+ return errSnapshotCycle
+ }
+
+ // Generate a new snapshot on top of the parent
+ parent := t.Snapshot(parentRoot)
+ if parent == nil {
+ return fmt.Errorf("parent [%s] snapshot missing", parentRoot)
+ }
+
+ // Every layer stored in t.layers implements the internal snapshot
+ // interface, so this assertion cannot fail in practice.
+ //nolint:forcetypeassert
+ snap := parent.(snapshot).Update(blockRoot, destructs, accounts, storage, logger)
+
+ // Save the new snapshot for later
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ t.layers[snap.root] = snap
+
+ return nil
+}
+
+// Cap traverses downwards the snapshot tree from a head block hash until the
+// number of allowed layers are crossed. All layers beyond the permitted number
+// are flattened downwards.
+//
+// Note, the final diff layer count in general will be one more than the amount
+// requested. This happens because the bottom-most diff layer is the accumulator
+// which may or may not overflow and cascade to disk. Since this last layer's
+// survival is only known *after* capping, we need to omit it from the count if
+// we want to ensure that *at least* the requested number of diff layers remain.
+func (t *Tree) Cap(root types.Hash, layers int) error {
+ // Retrieve the head snapshot to cap from
+ snap := t.Snapshot(root)
+ if snap == nil {
+ return fmt.Errorf("snapshot [%s] missing", root)
+ }
+
+ // Capping only makes sense on diff layers; a bare disk layer is an error
+ diff, ok := snap.(*diffLayer)
+ if !ok {
+ return fmt.Errorf("snapshot [%s] is disk layer", root)
+ }
+
+ // If the generator is still running, use a more aggressive cap
+ diff.origin.lock.RLock()
+ if diff.origin.genMarker != nil && layers > 8 {
+ layers = 8
+ }
+
+ diff.origin.lock.RUnlock()
+
+ // Run the internal capping and discard all stale layers
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ // Flattening the bottom-most diff layer requires special casing since there's
+ // no child to rewire to the grandparent. In that case we can fake a temporary
+ // child for the capping and then remove it.
+ if layers == 0 {
+ // If full commit was requested, flatten the diffs and merge onto disk
+ diff.lock.RLock()
+
+ //nolint:forcetypeassert
+ base := diffToDisk(diff.flatten().(*diffLayer))
+
+ diff.lock.RUnlock()
+
+ // Replace the entire snapshot tree with the flat base
+ t.layers = map[types.Hash]snapshot{base.root: base}
+
+ return nil
+ }
+
+ persisted := t.cap(diff, layers)
+
+ // Remove any layer that is stale or links into a stale layer.
+ // First build a parent -> children index so removals can cascade.
+ children := make(map[types.Hash][]types.Hash)
+
+ for root, snap := range t.layers {
+ if diff, ok := snap.(*diffLayer); ok {
+ parent := diff.parent.Root()
+ children[parent] = append(children[parent], root)
+ }
+ }
+
+ // remove drops a layer and, recursively, everything built on top of it
+ var remove func(root types.Hash)
+ remove = func(root types.Hash) {
+ delete(t.layers, root)
+
+ for _, child := range children[root] {
+ remove(child)
+ }
+
+ delete(children, root)
+ }
+
+ for root, snap := range t.layers {
+ if snap.Stale() {
+ remove(root)
+ }
+ }
+
+ // If the disk layer was modified, regenerate all the cumulative blooms
+ if persisted != nil {
+ var rebloom func(root types.Hash)
+ rebloom = func(root types.Hash) {
+ if diff, ok := t.layers[root].(*diffLayer); ok {
+ diff.rebloom(persisted)
+ }
+
+ for _, child := range children[root] {
+ rebloom(child)
+ }
+ }
+
+ rebloom(persisted.root)
+ }
+
+ return nil
+}
+
+// cap traverses downwards the diff tree until the number of allowed layers are
+// crossed. All diffs beyond the permitted number are flattened downwards. If the
+// layer limit is reached, memory cap is also enforced (but not before).
+//
+// The method returns the new disk layer if diffs were persisted into it.
+//
+// Note, the final diff layer count in general will be one more than the amount
+// requested. This happens because the bottom-most diff layer is the accumulator
+// which may or may not overflow and cascade to disk. Since this last layer's
+// survival is only known *after* capping, we need to omit it from the count if
+// we want to ensure that *at least* the requested number of diff layers remain.
+//
+// The caller must hold t.lock for writing.
+func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
+ // Dive until we run out of layers or reach the persistent database
+ for i := 0; i < layers-1; i++ {
+ // If we still have diff layers below, continue down
+ if parent, ok := diff.parent.(*diffLayer); ok {
+ diff = parent
+ } else {
+ // Diff stack too shallow, return without modifications
+ return nil
+ }
+ }
+ // We're out of layers, flatten anything below, stopping if it's the disk or if
+ // the memory limit is not yet exceeded.
+ switch parent := diff.parent.(type) {
+ case *diskLayer:
+ return nil
+
+ case *diffLayer:
+ // Hold the write lock until the flattened parent is linked correctly.
+ // Otherwise, the stale layer may be accessed by external reads in the
+ // meantime.
+ diff.lock.Lock()
+ defer diff.lock.Unlock()
+
+ // Flatten the parent into the grandparent. The flattening internally obtains a
+ // write lock on grandparent.
+ flattened, _ := parent.flatten().(*diffLayer)
+
+ t.layers[flattened.root] = flattened
+
+ // Invoke the hook if it's registered. Ugly hack.
+ if t.onFlatten != nil {
+ t.onFlatten()
+ }
+
+ diff.parent = flattened
+
+ if flattened.memory < aggregatorMemoryLimit {
+ // Accumulator layer is smaller than the limit, so we can abort, unless
+ // there's a snapshot being generated currently. In that case, the trie
+ // will move from underneath the generator so we **must** merge all the
+ // partial data down into the snapshot and restart the generation.
+ //nolint:forcetypeassert
+ if flattened.parent.(*diskLayer).genAbort == nil {
+ return nil
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unknown data layer: %T", parent))
+ }
+
+ // If the bottom-most layer is larger than our memory cap, persist to disk.
+ // After the switch above, diff.parent is guaranteed to be a *diffLayer.
+ bottom, _ := diff.parent.(*diffLayer)
+
+ bottom.lock.RLock()
+
+ base := diffToDisk(bottom)
+
+ bottom.lock.RUnlock()
+
+ t.layers[base.root] = base
+ diff.parent = base
+
+ return base
+}
+
+// diffToDisk merges a bottom-most diff into the persistent disk layer underneath
+// it. The method will panic if called onto a non-bottom-most diff layer.
+//
+// The disk layer persistence should be operated in an atomic way. All updates should
+// be discarded if the whole transition if not finished.
+func diffToDisk(bottom *diffLayer) *diskLayer {
+ var (
+ //nolint:forcetypeassert
+ base = bottom.parent.(*diskLayer)
+ batch = base.diskdb.NewBatch()
+ stats *generatorStats
+ logger = bottom.logger
+ )
+
+ // If the disk layer is running a snapshot generator, abort it
+ if base.genAbort != nil {
+ abort := make(chan *generatorStats)
+
+ base.genAbort <- abort
+ stats = <-abort
+ }
+
+ // Put the deletion in the batch writer, flush all updates in the final step.
+ rawdb.DeleteSnapshotRoot(batch)
+
+ // Mark the original base as stale as we're going to create a new wrapper
+ base.lock.Lock()
+ if base.stale {
+ panic("parent disk layer is stale") // we've committed into the same base from two children, boo
+ }
+
+ base.stale = true
+
+ base.lock.Unlock()
+
+ // Destroy all the destructed accounts from the database
+ for hash := range bottom.destructSet {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 {
+ continue
+ }
+
+ // Remove all storage slots
+ rawdb.DeleteAccountSnapshot(batch, hash)
+ base.cache.Set(hash[:], nil)
+
+ it := rawdb.IterateStorageSnapshots(base.diskdb, hash)
+ for it.Next() {
+ key := it.Key()
+ // NOTE(review): the Delete error is silently discarded here;
+ // presumably a failure would surface on the later batch.Write —
+ // confirm against the kvdb batch implementation.
+ batch.Delete(key)
+ // delete all cache snapshot key
+ base.cache.Del(key[rawdb.SnapshotPrefixLength:])
+ metrics.CounterInc(base.snapmetrics.flushStorageItemCount)
+
+ // Ensure we don't delete too much data blindly (contract can be
+ // huge). It's ok to flush, the root will go missing in case of a
+ // crash and we'll detect and regenerate the snapshot.
+ if batch.ValueSize() > kvdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ logger.Error("Failed to write storage deletions", "err", err)
+ os.Exit(1)
+ }
+
+ batch.Reset()
+ }
+ }
+ it.Release()
+ }
+
+ var (
+ cleanAccountWriteSize int64 = 0
+ flushAccountSize int64 = 0
+ )
+
+ // Push all updated accounts into the database
+ for hash, data := range bottom.accountData {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(hash[:], base.genMarker) > 0 {
+ continue
+ }
+
+ // Push the account to disk
+ rawdb.WriteAccountSnapshot(batch, hash, data)
+ base.cache.Set(hash[:], data)
+
+ // collect metrics
+ metrics.CounterInc(base.snapmetrics.flushAccountItemCount)
+ // whole data count
+ cleanAccountWriteSize += int64(len(data))
+ flushAccountSize += int64(len(data))
+
+ // Ensure we don't write too much data blindly. It's ok to flush, the
+ // root will go missing in case of a crash and we'll detect and regen
+ // the snapshot.
+ if batch.ValueSize() > kvdb.IdealBatchSize {
+ if err := batch.Write(); err != nil {
+ logger.Error("Failed to write storage deletions", "err", err)
+ os.Exit(1)
+ }
+
+ // reset batch for another write
+ batch.Reset()
+ }
+ }
+
+ metrics.HistogramObserve(base.snapmetrics.cleanAccountWriteSize, float64(cleanAccountWriteSize))
+ metrics.HistogramObserve(base.snapmetrics.flushAccountSize, float64(flushAccountSize))
+
+ var (
+ cleanStorageWriteSize int64 = 0
+ flushStorageSize int64 = 0
+ )
+
+ // Push all the storage slots into the database
+ for accountHash, storage := range bottom.storageData {
+ // Skip any account not covered yet by the snapshot
+ if base.genMarker != nil && bytes.Compare(accountHash[:], base.genMarker) > 0 {
+ continue
+ }
+
+ // Generation might be mid-account, track that case too
+ midAccount := base.genMarker != nil && bytes.Equal(accountHash[:], base.genMarker[:types.HashLength])
+
+ for storageHash, data := range storage {
+ // Skip any slot not covered yet by the snapshot
+ if midAccount && bytes.Compare(storageHash[:], base.genMarker[types.HashLength:]) > 0 {
+ continue
+ }
+
+ // Empty data means the slot was deleted in this diff
+ if len(data) > 0 {
+ rawdb.WriteStorageSnapshot(batch, accountHash, storageHash, data)
+ base.cache.Set(append(accountHash[:], storageHash[:]...), data)
+ cleanStorageWriteSize += int64(len(data))
+ } else {
+ rawdb.DeleteStorageSnapshot(batch, accountHash, storageHash)
+ base.cache.Set(append(accountHash[:], storageHash[:]...), nil)
+ }
+
+ // collection metrics
+ metrics.CounterInc(base.snapmetrics.flushStorageItemCount)
+ // whole data count
+ flushStorageSize += int64(len(data))
+ }
+ }
+
+ metrics.HistogramObserve(base.snapmetrics.cleanStorageWriteSize, float64(cleanStorageWriteSize))
+ metrics.HistogramObserve(base.snapmetrics.flushStorageSize, float64(flushStorageSize))
+
+ // Update the snapshot block marker and write any remainder data
+ rawdb.WriteSnapshotRoot(batch, bottom.root)
+
+ // Write out the generator progress marker and report
+ journalProgress(batch, base.genMarker, stats, logger)
+
+ // Flush all the updates in the single db operation. Ensure the
+ // disk layer transition is atomic.
+ if err := batch.Write(); err != nil {
+ logger.Error("Failed to write leftover snapshot", "err", err)
+ os.Exit(1)
+ }
+
+ logger.Debug("Journalled disk layer", "root", bottom.root, "complete", base.genMarker == nil)
+
+ // Build the replacement disk layer, reusing the old layer's cache and dbs
+ res := &diskLayer{
+ root: bottom.root,
+ cache: base.cache,
+ diskdb: base.diskdb,
+ triedb: base.triedb,
+ genMarker: base.genMarker,
+ genPending: base.genPending,
+ logger: base.logger,
+ snapmetrics: base.snapmetrics,
+ }
+
+ // If snapshot generation hasn't finished yet, port over all the starts and
+ // continue where the previous round left off.
+ //
+ // Note, the `base.genAbort` comparison is not used normally, it's checked
+ // to allow the tests to play with the marker without triggering this path.
+ if base.genMarker != nil && base.genAbort != nil {
+ res.genMarker = base.genMarker
+ res.genAbort = make(chan chan *generatorStats)
+
+ go res.generate(stats)
+ }
+
+ return res
+}
+
+// Journal commits an entire diff hierarchy to disk into a single journal entry.
+// This is meant to be used during shutdown to persist the snapshot without
+// flattening everything down (bad for reorgs).
+//
+// The method returns the root hash of the base layer that needs to be persisted
+// to disk as a trie too to allow continuing any pending generation op.
+func (t *Tree) Journal(root types.Hash) (types.Hash, error) {
+ // Retrieve the head snapshot to journal from
+ snap := t.Snapshot(root)
+ if snap == nil {
+ return types.Hash{}, fmt.Errorf("snapshot [%s] missing", root)
+ }
+
+ // Run the journaling
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ // Firstly write out the metadata of journal
+ journal := new(bytes.Buffer)
+ if err := rlp.Encode(journal, journalVersion); err != nil {
+ return types.Hash{}, err
+ }
+
+ diskroot := t.diskRoot()
+ if diskroot == (types.Hash{}) {
+ return types.Hash{}, errors.New("invalid disk root")
+ }
+
+ // Secondly write out the disk layer root, ensure the
+ // diff journal is continuous with disk.
+ if err := rlp.Encode(journal, diskroot); err != nil {
+ return types.Hash{}, err
+ }
+
+ // Finally write out the journal of each layer in reverse order.
+ base, err := snap.(snapshot).Journal(journal)
+ if err != nil {
+ return types.Hash{}, err
+ }
+
+ // Store the journal into the database and return
+ rawdb.WriteSnapshotJournal(t.diskdb, journal.Bytes())
+
+ return base, nil
+}
+
+// Rebuild wipes all available snapshot data from the persistent database and
+// discard all caches and diff layers. Afterwards, it starts a new snapshot
+// generator with the given root hash.
+func (t *Tree) Rebuild(root types.Hash) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ // Firstly delete any recovery flag in the database. Because now we are
+ // building a brand new snapshot. Also reenable the snapshot feature.
+ rawdb.DeleteSnapshotRecoveryNumber(t.diskdb)
+ rawdb.DeleteSnapshotDisabled(t.diskdb)
+
+ // Iterate over and mark all layers stale
+ for _, layer := range t.layers {
+ switch layer := layer.(type) {
+ case *diskLayer:
+ // If the base layer is generating, abort it and save
+ if layer.genAbort != nil {
+ abort := make(chan *generatorStats)
+ layer.genAbort <- abort
+ <-abort
+ }
+ // Layer should be inactive now, mark it as stale
+ layer.lock.Lock()
+ layer.stale = true
+ layer.lock.Unlock()
+ case *diffLayer:
+ // If the layer is a simple diff, simply mark as stale
+ layer.lock.Lock()
+ atomic.StoreUint32(&layer.stale, 1)
+ layer.lock.Unlock()
+ default:
+ panic(fmt.Sprintf("unknown layer type: %T", layer))
+ }
+ }
+ // Start generating a new snapshot from scratch on a background thread. The
+ // generator will run a wiper first if there's not one running right now.
+ t.logger.Info("Rebuilding state snapshot")
+
+ // Replace the whole tree with a single, freshly-generating disk layer
+ t.layers = map[types.Hash]snapshot{
+ root: generateSnapshot(t.diskdb, t.triedb, t.config.CacheSize, root, t.logger, t.snapmetrics),
+ }
+}
+
+// AccountIterator creates a new account iterator for the specified root hash and
+// seeks to a starting account hash. It returns ErrNotConstructed while the
+// snapshot is still being generated.
+func (t *Tree) AccountIterator(root types.Hash, seek types.Hash) (AccountIterator, error) {
+ ok, err := t.generating()
+ if err != nil {
+ return nil, err
+ }
+
+ if ok {
+ return nil, ErrNotConstructed
+ }
+
+ return newFastAccountIterator(t, root, seek)
+}
+
+// StorageIterator creates a new storage iterator for the specified root hash and
+// account. The iterator will be move to the specific start position. It returns
+// ErrNotConstructed while the snapshot is still being generated.
+func (t *Tree) StorageIterator(root types.Hash, account types.Hash, seek types.Hash) (StorageIterator, error) {
+ ok, err := t.generating()
+ if err != nil {
+ return nil, err
+ }
+
+ if ok {
+ return nil, ErrNotConstructed
+ }
+
+ return newFastStorageIterator(t, root, account, seek)
+}
+
+// Verify iterates the whole state(all the accounts as well as the corresponding storages)
+// with the specific root and compares the re-computed hash with the original one.
+//
+// NOTE(review): this is currently an unimplemented stub that always reports
+// success — callers must not rely on it for actual verification yet.
+func (t *Tree) Verify(root types.Hash) error {
+ return nil
+}
+
+// disklayer is an internal helper function to return the disk layer.
+// The lock of snapTree is assumed to be held already.
+//
+// Any layer in the map works as a starting point: a disk layer is returned
+// directly, while a diff layer resolves through its origin pointer, and all
+// diff layers share the single disk layer as origin.
+func (t *Tree) disklayer() *diskLayer {
+ // Grab an arbitrary layer from the map (iteration order is irrelevant here)
+ var snap snapshot
+ for _, s := range t.layers {
+ snap = s
+
+ break
+ }
+
+ if snap == nil {
+ return nil
+ }
+
+ switch layer := snap.(type) {
+ case *diskLayer:
+ return layer
+ case *diffLayer:
+ return layer.origin
+ default:
+ panic(fmt.Sprintf("%T: undefined layer", snap))
+ }
+}
+
+// diskRoot is a internal helper function to return the disk layer root.
+// The lock of snapTree is assumed to be held already. Returns the zero hash
+// when no disk layer exists.
+func (t *Tree) diskRoot() types.Hash {
+ disklayer := t.disklayer()
+ if disklayer == nil {
+ return types.Hash{}
+ }
+
+ return disklayer.Root()
+}
+
+// generating is an internal helper function which reports whether the snapshot
+// is still under the construction. A non-nil generation marker on the disk
+// layer means the generator has not finished yet.
+func (t *Tree) generating() (bool, error) {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ layer := t.disklayer()
+ if layer == nil {
+ return false, errors.New("disk layer is missing")
+ }
+
+ layer.lock.RLock()
+ defer layer.lock.RUnlock()
+
+ return layer.genMarker != nil, nil
+}
+
+// DiskRoot is a external helper function to return the disk layer root.
+// It is the lock-taking public counterpart of diskRoot.
+func (t *Tree) DiskRoot() types.Hash {
+ t.lock.Lock()
+ defer t.lock.Unlock()
+
+ return t.diskRoot()
+}
diff --git a/state/snapshot/snapshot_test.go b/state/snapshot/snapshot_test.go
new file mode 100644
index 0000000000..60989d0b1c
--- /dev/null
+++ b/state/snapshot/snapshot_test.go
@@ -0,0 +1,89 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+ "math/big"
+ "math/rand"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// randomHash generates a random blob of data and returns it as a hash.
+// math/rand is used deliberately: these helpers only need cheap,
+// non-cryptographic randomness for tests.
+func randomHash() types.Hash {
+ var hash types.Hash
+
+ // Check the two failure modes separately. The old combined check executed
+ // panic(err) even when err was nil (short read with no error), producing a
+ // confusing nil-valued panic.
+ n, err := rand.Read(hash[:])
+ if err != nil {
+ panic(err)
+ }
+
+ if n != types.HashLength {
+ panic("short random read for hash")
+ }
+
+ return hash
+}
+
+// randomAccount generates a random account and returns it RLP encoded.
+func randomAccount() []byte {
+ root := randomHash()
+
+ a := stypes.Account{
+ Balance: big.NewInt(rand.Int63()),
+ Nonce: rand.Uint64(),
+ StorageRoot: root,
+ CodeHash: types.EmptyRootHash.Bytes(),
+ }
+
+ // Encoding a fully-populated Account cannot fail, so the error is dropped
+ data, _ := rlp.EncodeToBytes(a)
+
+ return data
+}
+
+// randomAccountSet generates a set of random accounts with the given strings as
+// the account address hashes.
+func randomAccountSet(hashes ...string) map[types.Hash][]byte {
+ accounts := make(map[types.Hash][]byte)
+ for _, hash := range hashes {
+ accounts[types.StringToHash(hash)] = randomAccount()
+ }
+
+ return accounts
+}
+
+// randomStorageSet generates a set of random slots with the given strings as
+// the slot addresses. Slots listed in nilStorage are mapped to nil values,
+// modelling deleted entries; accounts beyond the length of hashes/nilStorage
+// simply get an empty slot map.
+func randomStorageSet(accounts []string, hashes [][]string, nilStorage [][]string) map[types.Hash]map[types.Hash][]byte {
+ storages := make(map[types.Hash]map[types.Hash][]byte)
+
+ for index, account := range accounts {
+ storages[types.StringToHash(account)] = make(map[types.Hash][]byte)
+
+ // Populate random live slots for this account, if provided
+ if index < len(hashes) {
+ hashes := hashes[index]
+ for _, hash := range hashes {
+ storages[types.StringToHash(account)][types.StringToHash(hash)] = randomHash().Bytes()
+ }
+ }
+
+ // Mark requested slots as deleted (nil value), if provided
+ if index < len(nilStorage) {
+ nils := nilStorage[index]
+ for _, hash := range nils {
+ storages[types.StringToHash(account)][types.StringToHash(hash)] = nil
+ }
+ }
+ }
+
+ return storages
+}
diff --git a/state/snapshot/sort.go b/state/snapshot/sort.go
new file mode 100644
index 0000000000..ec93c57e79
--- /dev/null
+++ b/state/snapshot/sort.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package snapshot
+
+import (
+ "bytes"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// hashes is a helper to implement sort.Interface, ordering hashes
+// lexicographically by their byte representation.
+type hashes []types.Hash
+
+// Len is the number of elements in the collection.
+func (hs hashes) Len() int { return len(hs) }
+
+// Less reports whether the element with index i should sort before the element
+// with index j.
+func (hs hashes) Less(i, j int) bool { return bytes.Compare(hs[i][:], hs[j][:]) < 0 }
+
+// Swap swaps the elements with indexes i and j.
+func (hs hashes) Swap(i, j int) { hs[i], hs[j] = hs[j], hs[i] }
diff --git a/state/snapshot/utils.go b/state/snapshot/utils.go
new file mode 100644
index 0000000000..de2873ee3c
--- /dev/null
+++ b/state/snapshot/utils.go
@@ -0,0 +1,211 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+ "bytes"
+ "fmt"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// CheckDanglingStorage iterates the snap storage data, and verifies that all
+// storage also has corresponding account data.
+func CheckDanglingStorage(chaindb kvdb.KVBatchStorage, logger kvdb.Logger) error {
+ if err := checkDanglingDiskStorage(chaindb, logger); err != nil {
+ logger.Error("Database check error", "err", err)
+ }
+
+ return checkDanglingMemStorage(chaindb, logger)
+}
+
+// checkDanglingDiskStorage checks if there is any 'dangling' storage data in the
+// disk-backed snapshot layer.
+func checkDanglingDiskStorage(chaindb kvdb.KVBatchStorage, logger kvdb.Logger) error {
+ var (
+ lastReport = time.Now()
+ start = time.Now()
+ lastKey []byte
+ it = rawdb.NewKeyLengthIterator(
+ chaindb.NewIterator(rawdb.SnapshotStoragePrefix, nil),
+ rawdb.SnapshotPrefixLength+2*types.HashLength,
+ )
+ )
+
+ logger.Info("Checking dangling snapshot disk storage")
+
+ defer it.Release()
+
+ for it.Next() {
+ k := it.Key()
+ accKey := k[rawdb.SnapshotPrefixLength : rawdb.SnapshotPrefixLength+types.HashLength]
+
+ if bytes.Equal(accKey, lastKey) {
+ // No need to look up for every slot
+ continue
+ }
+
+ lastKey = types.CopyBytes(accKey)
+
+ if time.Since(lastReport) > time.Second*8 {
+ logger.Info("Iterating snap storage",
+ "at", fmt.Sprintf("%#x", accKey),
+ "elapsed", types.PrettyDuration(time.Since(start)),
+ )
+
+ lastReport = time.Now()
+ }
+
+ if data := rawdb.ReadAccountSnapshot(chaindb, types.BytesToHash(accKey)); len(data) == 0 {
+ logger.Warn("Dangling storage - missing account",
+ "account", fmt.Sprintf("%#x", accKey),
+ "storagekey", fmt.Sprintf("%#x", k),
+ )
+
+ return fmt.Errorf("dangling snapshot storage account %#x", accKey)
+ }
+ }
+
+ logger.Info("Verified the snapshot disk storage", "time", types.PrettyDuration(time.Since(start)), "err", it.Error())
+
+ return nil
+}
+
+// checkDanglingMemStorage checks if there is any 'dangling' storage in the journalled
+// snapshot difflayers.
+func checkDanglingMemStorage(db kvdb.KVBatchStorage, logger kvdb.Logger) error {
+ start := time.Now()
+
+ logger.Info("Checking dangling journalled storage")
+
+ err := iterateJournal(logger, db,
+ func(
+ pRoot,
+ root types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ ) error {
+ for accHash := range storage {
+ if _, ok := accounts[accHash]; !ok {
+ logger.Error("Dangling storage - missing account", "account", fmt.Sprintf("%#x", accHash), "root", root)
+ }
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ logger.Info("Failed to resolve snapshot journal", "err", err)
+
+ return err
+ }
+
+ logger.Info("Verified the snapshot journalled storage", "time", types.PrettyDuration(time.Since(start)))
+
+ return nil
+}
+
+// CheckJournalAccount shows information about an account, from the disk layer and
+// up through the diff layers.
+func CheckJournalAccount(db kvdb.KVBatchStorage, hash types.Hash, logger kvdb.Logger) error {
+ // Look up the disk layer first
+ baseRoot := rawdb.ReadSnapshotRoot(db)
+ fmt.Printf("Disklayer: Root: %x\n", baseRoot)
+
+ if data := rawdb.ReadAccountSnapshot(db, hash); data != nil {
+ account := new(Account)
+
+ if err := rlp.DecodeBytes(data, account); err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("\taccount.nonce: %d\n", account.Nonce)
+ fmt.Printf("\taccount.balance: %x\n", account.Balance)
+ fmt.Printf("\taccount.root: %x\n", account.Root)
+ fmt.Printf("\taccount.codehash: %x\n", account.CodeHash)
+ }
+
+ // Check storage
+ {
+ it := rawdb.NewKeyLengthIterator(
+ db.NewIterator(append(rawdb.SnapshotStoragePrefix, hash.Bytes()...), nil),
+ rawdb.SnapshotPrefixLength+2*types.HashLength,
+ )
+ fmt.Printf("\tStorage:\n")
+
+ for it.Next() {
+ slot := it.Key()[33:]
+ fmt.Printf("\t\t%x: %x\n", slot, it.Value())
+ }
+
+ it.Release()
+ }
+
+ var depth = 0
+
+ return iterateJournal(logger, db,
+ func(
+ pRoot,
+ root types.Hash,
+ destructs map[types.Hash]struct{},
+ accounts map[types.Hash][]byte,
+ storage map[types.Hash]map[types.Hash][]byte,
+ ) error {
+ _, a := accounts[hash]
+ _, b := destructs[hash]
+ _, c := storage[hash]
+ depth++
+
+ if !a && !b && !c {
+ return nil
+ }
+
+ fmt.Printf("Disklayer+%d: Root: %x, parent %x\n", depth, root, pRoot)
+
+ if data, ok := accounts[hash]; ok {
+ account := new(Account)
+
+ if err := rlp.DecodeBytes(data, account); err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("\taccount.nonce: %d\n", account.Nonce)
+ fmt.Printf("\taccount.balance: %x\n", account.Balance)
+ fmt.Printf("\taccount.root: %x\n", account.Root)
+ fmt.Printf("\taccount.codehash: %x\n", account.CodeHash)
+ }
+
+ if _, ok := destructs[hash]; ok {
+ fmt.Printf("\t Destructed!")
+ }
+
+ if data, ok := storage[hash]; ok {
+ fmt.Printf("\tStorage\n")
+
+ for k, v := range data {
+ fmt.Printf("\t\t%x: %x\n", k, v)
+ }
+ }
+
+ return nil
+ })
+}
diff --git a/state/state.go b/state/state.go
deleted file mode 100644
index 274a4d3930..0000000000
--- a/state/state.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package state
-
-import (
- "bytes"
- "fmt"
- "math/big"
-
- "github.com/dogechain-lab/fastrlp"
- iradix "github.com/hashicorp/go-immutable-radix"
-
- "github.com/dogechain-lab/dogechain/crypto"
- "github.com/dogechain-lab/dogechain/types"
-)
-
-type State interface {
- NewSnapshotAt(types.Hash) (Snapshot, error)
- NewSnapshot() Snapshot
- GetCode(hash types.Hash) ([]byte, bool)
-}
-
-type Snapshot interface {
- snapshotReader
-
- Commit(objs []*Object) (Snapshot, []byte, error)
-}
-
-// account trie
-type accountTrie interface {
- Get(k []byte) ([]byte, bool)
-}
-
-// Account is the account reference in the ethereum state
-type Account struct {
- Nonce uint64
- Balance *big.Int
- Root types.Hash
- CodeHash []byte
-}
-
-func (a *Account) MarshalWith(ar *fastrlp.Arena) *fastrlp.Value {
- v := ar.NewArray()
- v.Set(ar.NewUint(a.Nonce))
- v.Set(ar.NewBigInt(a.Balance))
- v.Set(ar.NewBytes(a.Root.Bytes()))
- v.Set(ar.NewBytes(a.CodeHash))
-
- return v
-}
-
-var accountParserPool fastrlp.ParserPool
-
-func (a *Account) UnmarshalRlp(b []byte) error {
- p := accountParserPool.Get()
- defer accountParserPool.Put(p)
-
- v, err := p.Parse(b)
- if err != nil {
- return err
- }
-
- elems, err := v.GetElems()
-
- if err != nil {
- return err
- }
-
- if len(elems) < 4 {
- return fmt.Errorf("incorrect number of elements to decode account, expected at least 4 but found %d",
- len(elems))
- }
-
- // nonce
- if a.Nonce, err = elems[0].GetUint64(); err != nil {
- return err
- }
- // balance
- if a.Balance == nil {
- a.Balance = new(big.Int)
- }
-
- if err = elems[1].GetBigInt(a.Balance); err != nil {
- return err
- }
- // root
- if err = elems[2].GetHash(a.Root[:]); err != nil {
- return err
- }
- // codeHash
- if a.CodeHash, err = elems[3].GetBytes(a.CodeHash[:0]); err != nil {
- return err
- }
-
- return nil
-}
-
-func (a *Account) String() string {
- return fmt.Sprintf("%d %s", a.Nonce, a.Balance.String())
-}
-
-func (a *Account) Copy() *Account {
- aa := new(Account)
-
- aa.Balance = new(big.Int).SetBytes(a.Balance.Bytes())
- aa.Nonce = a.Nonce
- aa.CodeHash = a.CodeHash
- aa.Root = a.Root
-
- return aa
-}
-
-var emptyCodeHash = crypto.Keccak256(nil)
-
-// StateObject is the internal representation of the account
-type StateObject struct {
- Account *Account
- Code []byte
- Suicide bool
- Deleted bool
- DirtyCode bool
- Txn *iradix.Txn
-}
-
-func (s *StateObject) Empty() bool {
- return s.Account.Nonce == 0 && s.Account.Balance.Sign() == 0 && bytes.Equal(s.Account.CodeHash, emptyCodeHash)
-}
-
-// Copy makes a copy of the state object
-func (s *StateObject) Copy() *StateObject {
- ss := new(StateObject)
-
- // copy account
- ss.Account = s.Account.Copy()
-
- ss.Suicide = s.Suicide
- ss.Deleted = s.Deleted
- ss.DirtyCode = s.DirtyCode
- ss.Code = s.Code
-
- if s.Txn != nil {
- ss.Txn = s.Txn.CommitOnly().Txn()
- }
-
- return ss
-}
-
-// Object is the serialization of the radix object (can be merged to StateObject?).
-type Object struct {
- Address types.Address
- CodeHash types.Hash
- Balance *big.Int
- Root types.Hash
- Nonce uint64
- Deleted bool
-
- // TODO: Move this to executor
- DirtyCode bool
- Code []byte
-
- Storage []*StorageObject
-}
-
-// StorageObject is an entry in the storage
-type StorageObject struct {
- Deleted bool
- Key []byte
- Val []byte
-}
diff --git a/state/state_object.go b/state/state_object.go
new file mode 100644
index 0000000000..2f24840351
--- /dev/null
+++ b/state/state_object.go
@@ -0,0 +1,198 @@
+package state
+
+import (
+ "bytes"
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+ iradix "github.com/hashicorp/go-immutable-radix"
+)
+
+var emptyCodeHash = types.EmptyCodeHash.Bytes()
+
+type State interface {
+ NewSnapshotAt(types.Hash) (Snapshot, error)
+ NewSnapshot() Snapshot
+ GetCode(hash types.Hash) ([]byte, bool)
+}
+
+type Snapshot interface {
+ snapshotReader
+
+ // Change object state root if there is any update of storage
+ Commit(objs []*stypes.Object) (Snapshot, []byte, error)
+}
+
+// stateObject is the internal representation of the account
+type stateObject struct {
+ account *stypes.Account
+ code []byte
+
+ // status fields, open readable?
+ suicide bool
+ deleted bool
+ dirtyCode bool
+
+ // live object radix trie Transaction. Set it only when there is a trie
+ radixTxn *iradix.Txn
+
+ // associated transition Transaction, to update its journal
+ transitionTxn *Txn
+
+ // for quick search, inner fields only
+ address types.Address
+ addrHash types.Hash
+}
+
+// newStateObject create a new state object
+func newStateObject(transitionTxn *Txn, address types.Address, account *stypes.Account) *stateObject {
+ if account == nil {
+ account = new(stypes.Account)
+ }
+
+ if account.Balance == nil {
+ account.Balance = new(big.Int)
+ }
+
+ if account.CodeHash == nil {
+ account.CodeHash = emptyCodeHash
+ }
+
+ if account.StorageRoot == (types.Hash{}) {
+ account.StorageRoot = emptyStateHash
+ }
+
+ return stateObjectWithAddress(transitionTxn, address, account)
+}
+
+func stateObjectWithAddress(transitionTxn *Txn, address types.Address, account *stypes.Account) *stateObject {
+ return &stateObject{
+ account: account,
+ address: address,
+ addrHash: crypto.Keccak256Hash(address[:]),
+ transitionTxn: transitionTxn,
+ }
+}
+
+func (s *stateObject) Empty() bool {
+ return s.Nonce() == 0 && s.Balance().Sign() == 0 && bytes.Equal(s.CodeHash(), emptyCodeHash)
+}
+
+// Copy makes a copy of the state object
+func (s *stateObject) Copy() *stateObject {
+ ss := new(stateObject)
+
+ // copy account
+ ss.account = s.account.Copy()
+
+ ss.suicide = s.suicide
+ ss.deleted = s.deleted
+ ss.dirtyCode = s.dirtyCode
+ ss.code = s.code
+
+ if s.radixTxn != nil {
+ ss.radixTxn = s.radixTxn.CommitOnly().Txn()
+ }
+
+ ss.transitionTxn = s.transitionTxn
+
+ // search key
+ ss.address = s.address
+ ss.addrHash = s.addrHash
+
+ return ss
+}
+
+func (s *stateObject) AddBalance(balance *big.Int) {
+ s.SetBalance(new(big.Int).Add(s.Balance(), balance))
+}
+
+func (s *stateObject) SubBalance(balance *big.Int) {
+ s.SetBalance(new(big.Int).Sub(s.Balance(), balance))
+}
+
+func (s *stateObject) SetBalance(balance *big.Int) {
+ s.transitionTxn.journal.append(balanceChange{
+ account: &s.address,
+ prev: new(big.Int).Set(s.Balance()),
+ })
+ s.setBalance(balance)
+}
+
+func (s *stateObject) setBalance(balance *big.Int) {
+ s.account.Balance = balance
+}
+
+// Address returns the address of the contract/account
+func (s *stateObject) Address() types.Address {
+ return s.address
+}
+
+func (s *stateObject) AddressHash() types.Hash {
+ return s.addrHash
+}
+
+// Code returns the contract code associated with this object, if any.
+func (s *stateObject) Code() []byte {
+ if s.dirtyCode {
+ return s.code
+ }
+
+ if bytes.Equal(s.CodeHash(), emptyCodeHash) {
+ return nil
+ }
+
+ code, _ := s.transitionTxn.snapshot.GetCode(types.BytesToHash(s.CodeHash()))
+
+ // cache the code, but it is not dirty
+ s.code = code
+
+ return code
+}
+
+func (s *stateObject) SetCode(codeHash types.Hash, code []byte) {
+ prevcode := s.Code()
+ // journal change
+ s.transitionTxn.journal.append(codeChange{
+ account: &s.address,
+ prevhash: s.CodeHash(),
+ prevcode: prevcode,
+ })
+ s.setCode(codeHash, code)
+}
+
+func (s *stateObject) setCode(codeHash types.Hash, code []byte) {
+ s.code = code
+ s.dirtyCode = true
+ s.account.CodeHash = codeHash[:]
+}
+
+func (s *stateObject) CodeHash() []byte {
+ return s.account.CodeHash
+}
+
+func (s *stateObject) Nonce() uint64 {
+ return s.account.Nonce
+}
+
+func (s *stateObject) SetNonce(nonce uint64) {
+ s.transitionTxn.journal.append(nonceChange{
+ account: &s.address,
+ prev: s.Nonce(),
+ })
+ s.setNonce(nonce)
+}
+
+func (s *stateObject) setNonce(nonce uint64) {
+ s.account.Nonce = nonce
+}
+
+func (s *stateObject) Balance() *big.Int {
+ return s.account.Balance
+}
+
+func (s *stateObject) StorageRoot() types.Hash {
+ return s.account.StorageRoot
+}
diff --git a/state/stypes/account.go b/state/stypes/account.go
new file mode 100644
index 0000000000..f6810fe9f2
--- /dev/null
+++ b/state/stypes/account.go
@@ -0,0 +1,88 @@
+package stypes
+
+import (
+ "fmt"
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/dogechain-lab/fastrlp"
+)
+
+// Account is the account reference in the ethereum state
+type Account struct {
+ Nonce uint64
+ Balance *big.Int
+ StorageRoot types.Hash // storage root
+ CodeHash []byte
+}
+
+func (a *Account) MarshalWith(ar *fastrlp.Arena) *fastrlp.Value {
+ v := ar.NewArray()
+ v.Set(ar.NewUint(a.Nonce))
+ v.Set(ar.NewBigInt(a.Balance))
+ v.Set(ar.NewBytes(a.StorageRoot.Bytes()))
+ v.Set(ar.NewBytes(a.CodeHash))
+
+ return v
+}
+
+func (a *Account) UnmarshalRlp(b []byte) error {
+ v, err := types.RlpUnmarshal(b)
+ if err != nil {
+ return err
+ }
+
+ elems, err := v.GetElems()
+
+ if err != nil {
+ return err
+ }
+
+ if len(elems) < 4 {
+ return fmt.Errorf("incorrect number of elements to decode account, expected at least 4 but found %d",
+ len(elems))
+ }
+
+ // nonce
+ if a.Nonce, err = elems[0].GetUint64(); err != nil {
+ return err
+ }
+ // balance
+ if a.Balance == nil {
+ a.Balance = new(big.Int)
+ }
+
+ if err = elems[1].GetBigInt(a.Balance); err != nil {
+ return err
+ }
+ // root
+ if err = elems[2].GetHash(a.StorageRoot[:]); err != nil {
+ return err
+ }
+ // codeHash
+ if a.CodeHash, err = elems[3].GetBytes(a.CodeHash[:0]); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (a *Account) String() string {
+ return fmt.Sprintf("%d %s", a.Nonce, a.Balance.String())
+}
+
+func (a *Account) Copy() *Account {
+ aa := new(Account)
+
+ aa.Balance = new(big.Int)
+
+ if a.Balance != nil {
+ aa.Balance.Set(a.Balance)
+ }
+
+ aa.Nonce = a.Nonce
+ aa.CodeHash = a.CodeHash
+ aa.StorageRoot = a.StorageRoot
+
+ return aa
+}
diff --git a/state/stypes/object.go b/state/stypes/object.go
new file mode 100644
index 0000000000..ac20cf1f12
--- /dev/null
+++ b/state/stypes/object.go
@@ -0,0 +1,30 @@
+package stypes
+
+import (
+ "math/big"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Object is the serialization of the radix object (can be merged to StateObject?).
+type Object struct {
+ Address types.Address
+ CodeHash types.Hash
+ Balance *big.Int
+ Root types.Hash
+ Nonce uint64
+ Deleted bool
+
+ // TODO: Move this to executor
+ DirtyCode bool
+ Code []byte
+
+ Storage []*StorageObject
+}
+
+// StorageObject is an entry in the storage
+type StorageObject struct {
+ Deleted bool
+ Key []byte
+ Val []byte
+}
diff --git a/state/testing.go b/state/testing.go
index 5d2ed57196..aa3d9c27b2 100644
--- a/state/testing.go
+++ b/state/testing.go
@@ -40,40 +40,40 @@ type buildPreState func(p PreStates) Snapshot
func TestState(t *testing.T, buildPreState buildPreState) {
t.Helper()
- t.Run("", func(t *testing.T) {
+ t.Run("write state", func(t *testing.T) {
testWriteState(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("write empty state", func(t *testing.T) {
testWriteEmptyState(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("update state with empty", func(t *testing.T) {
testUpdateStateWithEmpty(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("suicide account in prestate", func(t *testing.T) {
testSuicideAccountInPreState(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("suicide account", func(t *testing.T) {
testSuicideAccount(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("suicide account with data", func(t *testing.T) {
testSuicideAccountWithData(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("suicide coinbase", func(t *testing.T) {
testSuicideCoinbase(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("suicide with intermediate commit", func(t *testing.T) {
testSuicideWithIntermediateCommit(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("restart refunds", func(t *testing.T) {
testRestartRefunds(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("change prestate account balance to zero", func(t *testing.T) {
testChangePrestateAccountBalanceToZero(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("change account balance to zero", func(t *testing.T) {
testChangeAccountBalanceToZero(t, buildPreState)
})
- t.Run("", func(t *testing.T) {
+ t.Run("delete common state root", func(t *testing.T) {
testDeleteCommonStateRoot(t, buildPreState)
})
}
diff --git a/state/txn.go b/state/txn.go
index c071dac2f7..e9f587f076 100644
--- a/state/txn.go
+++ b/state/txn.go
@@ -1,14 +1,19 @@
package state
import (
+ "bytes"
"math/big"
-
- iradix "github.com/hashicorp/go-immutable-radix"
+ "sort"
"github.com/dogechain-lab/dogechain/chain"
"github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/state/utils"
"github.com/dogechain-lab/dogechain/types"
+ "github.com/dogechain-lab/fastrlp"
+ iradix "github.com/hashicorp/go-immutable-radix"
)
var emptyStateHash = types.StringToHash("0x56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
@@ -23,16 +28,36 @@ var (
// snapshotReader is snapshot read only APIs
type snapshotReader interface {
- GetStorage(addr types.Address, root types.Hash, key types.Hash) (types.Hash, error)
- GetAccount(addr types.Address) (*Account, error)
+ GetStorage(addr types.Address, storageRoot types.Hash, key types.Hash) (types.Hash, error)
+ GetAccount(addr types.Address) (*stypes.Account, error)
GetCode(hash types.Hash) ([]byte, bool)
}
+type revision struct {
+ id int
+ journalIndex int
+}
+
// Txn is a reference of the state
type Txn struct {
snapshot snapshotReader
snapshots []*iradix.Tree
- txn *iradix.Txn
+ // current radix trie transaction, caching live objects, including stateobject,
+ // log, refund
+ txn *iradix.Txn
+
+ // for caching world state
+ snap snapshot.Snapshot
+ snapDestructs map[types.Hash]struct{} // deleted and waiting for destruction
+ snapAccounts map[types.Hash][]byte // live snapshot accounts
+ // live snapshot storages map. [accountHash]map[slotHash]hashValue
+ // keep the structure the same as the persistence layer
+ snapStorage map[types.Hash]map[types.Hash][]byte
+
+ // Journal of state modifications. This is the backbone of
+ // RevertToSnapshot when enabling cache snapshot.
+ journal *journal
+ validRevisions []revision
}
func NewTxn(snapshot Snapshot) *Txn {
@@ -43,14 +68,39 @@ func newTxn(snapshot snapshotReader) *Txn {
i := iradix.New()
return &Txn{
- snapshot: snapshot,
- snapshots: []*iradix.Tree{},
- txn: i.Txn(),
+ snapshot: snapshot,
+ snapshots: []*iradix.Tree{},
+ txn: i.Txn(),
+ journal: newJournal(),
+ validRevisions: []revision{},
}
}
-func (txn *Txn) hashit(src []byte) []byte {
- return crypto.Keccak256(src)
+// SetSnap sets up the world state snapshot
+func (txn *Txn) SetSnap(
+ snap snapshot.Snapshot,
+) {
+ txn.snap = snap
+ if txn.snap != nil {
+ txn.snapDestructs = make(map[types.Hash]struct{})
+ txn.snapAccounts = make(map[types.Hash][]byte)
+ txn.snapStorage = make(map[types.Hash]map[types.Hash][]byte)
+ }
+}
+
+func (txn *Txn) GetSnapObjects() (
+ snapDestructs map[types.Hash]struct{},
+ snapAccounts map[types.Hash][]byte,
+ snapStorage map[types.Hash]map[types.Hash][]byte,
+) {
+ return txn.snapDestructs, txn.snapAccounts, txn.snapStorage
+}
+
+// CleanSnap cleans current snapshots
+func (txn *Txn) CleanSnap() {
+ if txn.snap != nil {
+ txn.snap, txn.snapDestructs, txn.snapAccounts, txn.snapStorage = nil, nil, nil, nil
+ }
}
// Snapshot takes a snapshot at this point in time
@@ -59,6 +109,8 @@ func (txn *Txn) Snapshot() int {
id := len(txn.snapshots)
txn.snapshots = append(txn.snapshots, t)
+ // append valid revision for journal
+ txn.validRevisions = append(txn.validRevisions, revision{id, txn.journal.length()})
return id
}
@@ -71,54 +123,116 @@ func (txn *Txn) RevertToSnapshot(id int) {
tree := txn.snapshots[id]
txn.txn = tree.Txn()
+
+ // If state snapshotting is active, we should reset to its original value,
+ // otherwise the resurrect account or transient update will be persisted
+ // into the snapshot tree, and damage the whole world state.
+ idx := sort.Search(len(txn.validRevisions), func(i int) bool {
+ return txn.validRevisions[i].id >= id
+ })
+
+ // Find the snapshot in the stack of valid snapshots.
+ cachedSnapshot := txn.validRevisions[idx].journalIndex
+
+ // Replay the journal to undo changes
+ txn.journal.revert(txn, cachedSnapshot)
+ // remove invalidated snapshots
+ txn.validRevisions = txn.validRevisions[:idx]
}
-// GetAccount returns an account
-func (txn *Txn) GetAccount(addr types.Address) (*Account, bool) {
- object, exists := txn.getStateObject(addr)
- if !exists {
- return nil, false
+func (txn *Txn) clearJournal() {
+ if len(txn.journal.entries) > 0 {
+ txn.journal = newJournal()
+ }
+
+ // Snapshots can be created without journal entries
+ txn.validRevisions = txn.validRevisions[:0]
+}
+
+func (txn *Txn) getStateObject(addr types.Address) (*stateObject, bool) {
+ if obj := txn.getDeletedStateObject(addr); obj != nil && !obj.deleted {
+ return obj, true
}
- return object.Account, true
+ return nil, false
}
-func (txn *Txn) getStateObject(addr types.Address) (*StateObject, bool) {
+func (txn *Txn) getDeletedStateObject(addr types.Address) *stateObject {
// Try to get state from radix tree which holds transient states during block processing first
- val, exists := txn.txn.Get(addr.Bytes())
- if exists {
- obj := val.(*StateObject) //nolint:forcetypeassert
- if obj.Deleted {
- return nil, false
+ if val, exists := txn.txn.Get(addr.Bytes()); exists {
+ obj := val.(*stateObject) //nolint:forcetypeassert
+
+ return obj.Copy()
+ }
+
+ var (
+ account *stypes.Account
+ )
+
+ // If no transient objects are available, attempt to use snapshots
+ if txn.snap != nil {
+ if acc, err := txn.snap.Account(crypto.Keccak256Hash(addr.Bytes())); err == nil { // got
+ if acc == nil {
+ return nil
+ }
+
+ account = acc
+
+ if account.StorageRoot == types.ZeroHash {
+ account.StorageRoot = types.EmptyRootHash
+ }
+
+ if len(account.CodeHash) == 0 {
+ account.CodeHash = emptyCodeHash
+ }
}
+ }
- return obj.Copy(), true
+ // If snapshot unavailable or reading from it failed, load from the database
+ if account == nil {
+ var err error
+
+ account, err = txn.snapshot.GetAccount(addr)
+ if err != nil {
+ return nil
+ } else if account == nil {
+ return nil
+ }
}
- account, err := txn.snapshot.GetAccount(addr)
- if err != nil {
- return nil, false
- } else if account == nil {
- return nil, false
+ return stateObjectWithAddress(txn, addr, account.Copy())
+}
+
+// updateSnapAccount updates snap account by object
+//
+// update live object or revert to some journaled object
+func (txn *Txn) updateSnapAccount(object *stateObject) {
+ if txn.snap == nil || object == nil {
+ return
}
- obj := &StateObject{
- Account: account.Copy(),
+ if object.suicide {
+ delete(txn.snapAccounts, object.addrHash)
+
+ return
}
- return obj, true
+ // If state snapshotting is active, cache the data til commit. Note, this
+ // update mechanism is not symmetric to the deletion, because whereas it is
+ // enough to track account updates at commit time, deletions need tracking
+ // at transaction boundary level to ensure we capture state clearing.
+ txn.snapAccounts[object.AddressHash()] = snapshot.SlimAccountRLP(
+ object.Nonce(),
+ object.Balance(),
+ object.StorageRoot(),
+ object.CodeHash(),
+ )
}
-func (txn *Txn) upsertAccount(addr types.Address, create bool, f func(object *StateObject)) {
+func (txn *Txn) upsertAccount(addr types.Address, create bool, f func(object *stateObject)) {
object, exists := txn.getStateObject(addr)
if !exists && create {
- object = &StateObject{
- Account: &Account{
- Balance: big.NewInt(0),
- CodeHash: emptyCodeHash,
- Root: emptyStateHash,
- },
- }
+ object = newStateObject(txn, addr, nil)
}
// run the callback to modify the account
@@ -127,23 +241,28 @@ func (txn *Txn) upsertAccount(addr types.Address, create bool, f func(object *St
if object != nil {
txn.txn.Insert(addr.Bytes(), object)
}
+
+ txn.updateSnapAccount(object)
}
func (txn *Txn) AddSealingReward(addr types.Address, balance *big.Int) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- if object.Suicide {
- *object = *newStateObject(txn)
- object.Account.Balance.SetBytes(balance.Bytes())
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ if object.suicide {
+ // create a balance-only object if it suicided
+ *object = *newStateObject(txn, addr, &stypes.Account{
+ Balance: new(big.Int).SetBytes(balance.Bytes()),
+ })
} else {
- object.Account.Balance.Add(object.Account.Balance, balance)
+ object.AddBalance(balance)
}
})
}
// AddBalance adds balance
func (txn *Txn) AddBalance(addr types.Address, balance *big.Int) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.Balance.Add(object.Account.Balance, balance)
+ // update the account even it add 0
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.AddBalance(balance)
})
}
@@ -159,8 +278,8 @@ func (txn *Txn) SubBalance(addr types.Address, amount *big.Int) error {
return runtime.ErrNotEnoughFunds
}
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.Balance.Sub(object.Account.Balance, amount)
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.SubBalance(amount)
})
return nil
@@ -168,8 +287,8 @@ func (txn *Txn) SubBalance(addr types.Address, amount *big.Int) error {
// SetBalance sets the balance
func (txn *Txn) SetBalance(addr types.Address, balance *big.Int) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.Balance.SetBytes(balance.Bytes())
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.SetBalance(balance)
})
}
@@ -180,7 +299,7 @@ func (txn *Txn) GetBalance(addr types.Address) *big.Int {
return big.NewInt(0)
}
- return object.Account.Balance
+ return object.Balance()
}
func (txn *Txn) EmitLog(addr types.Address, topics []types.Hash, data []byte) {
@@ -306,21 +425,24 @@ func (txn *Txn) SetState(
key,
value types.Hash,
) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- if object.Txn == nil {
- object.Txn = iradix.New().Txn()
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ if object.radixTxn == nil {
+ object.radixTxn = iradix.New().Txn()
}
if value == zeroHash {
- object.Txn.Insert(key.Bytes(), nil)
+ object.radixTxn.Insert(key.Bytes(), nil)
} else {
- object.Txn.Insert(key.Bytes(), value.Bytes())
+ object.radixTxn.Insert(key.Bytes(), value.Bytes())
}
})
}
// GetState returns the state of the address at a given key
+//
+// The state might be transient, remember to query the uncommitted trie
func (txn *Txn) GetState(addr types.Address, slot types.Hash) (types.Hash, error) {
+ // check account existence, and get its latest storage root
object, exists := txn.getStateObject(addr)
if !exists {
return types.Hash{}, nil
@@ -329,8 +451,8 @@ func (txn *Txn) GetState(addr types.Address, slot types.Hash) (types.Hash, error
// Try to get account state from radix tree first
// Because the latest account state should be in in-memory radix tree
// if account state update happened in previous transactions of same block
- if object.Txn != nil {
- if val, ok := object.Txn.Get(slot.Bytes()); ok {
+ if object.radixTxn != nil {
+ if val, ok := object.radixTxn.Get(slot.Bytes()); ok {
if val == nil {
return types.Hash{}, nil
}
@@ -339,23 +461,23 @@ func (txn *Txn) GetState(addr types.Address, slot types.Hash) (types.Hash, error
}
}
- // get it from storage
- return txn.snapshot.GetStorage(addr, object.Account.Root, slot)
+ // query the committed state
+ return txn.getCommittedObjectState(object, slot)
}
// Nonce
// IncrNonce increases the nonce of the address
func (txn *Txn) IncrNonce(addr types.Address) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.Nonce++
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.SetNonce(object.Nonce() + 1)
})
}
-// SetNonce reduces the balance
+// SetNonce set nonce directly
func (txn *Txn) SetNonce(addr types.Address, nonce uint64) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.Nonce = nonce
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.SetNonce(nonce)
})
}
@@ -366,17 +488,15 @@ func (txn *Txn) GetNonce(addr types.Address) uint64 {
return 0
}
- return object.Account.Nonce
+ return object.Nonce()
}
// Code
// SetCode sets the code for an address
func (txn *Txn) SetCode(addr types.Address, code []byte) {
- txn.upsertAccount(addr, true, func(object *StateObject) {
- object.Account.CodeHash = crypto.Keccak256(code)
- object.DirtyCode = true
- object.Code = code
+ txn.upsertAccount(addr, true, func(object *stateObject) {
+ object.SetCode(crypto.Keccak256Hash(code), code)
})
}
@@ -386,14 +506,7 @@ func (txn *Txn) GetCode(addr types.Address) []byte {
return nil
}
- if object.DirtyCode {
- return object.Code
- }
-
- // TODO: handle error
- code, _ := txn.snapshot.GetCode(types.BytesToHash(object.Account.CodeHash))
-
- return code
+ return object.Code()
}
func (txn *Txn) GetCodeSize(addr types.Address) int {
@@ -406,22 +519,35 @@ func (txn *Txn) GetCodeHash(addr types.Address) types.Hash {
return types.Hash{}
}
- return types.BytesToHash(object.Account.CodeHash)
+ return types.BytesToHash(object.CodeHash())
}
// Suicide marks the given account as suicided
func (txn *Txn) Suicide(addr types.Address) bool {
var suicided bool
- txn.upsertAccount(addr, false, func(object *StateObject) {
- if object == nil || object.Suicide {
+ txn.upsertAccount(addr, false, func(object *stateObject) {
+ change := suicideChange{
+ account: &addr,
+ prevbalance: new(big.Int),
+ }
+ // cache prev object
+ if object != nil {
+ change.prev = object.suicide
+ change.prevbalance.Set(object.Balance())
+ }
+ // journal change
+ txn.journal.append(change)
+
+ // update value
+ if object == nil || object.suicide {
suicided = false
} else {
suicided = true
- object.Suicide = true
+ object.suicide = true
}
if object != nil {
- object.Account.Balance = new(big.Int)
+ object.SetBalance(new(big.Int))
}
})
@@ -432,7 +558,7 @@ func (txn *Txn) Suicide(addr types.Address) bool {
func (txn *Txn) HasSuicided(addr types.Address) bool {
object, exists := txn.getStateObject(addr)
- return exists && object.Suicide
+ return exists && object.suicide
}
// Refund
@@ -446,6 +572,7 @@ func (txn *Txn) SubRefund(gas uint64) {
txn.txn.Insert(refundIndex, refund)
}
+// Logs returns and clears all logs held in txn trie
func (txn *Txn) Logs() []*types.Log {
data, exists := txn.txn.Get(logIndex)
if !exists {
@@ -467,24 +594,48 @@ func (txn *Txn) GetRefund() uint64 {
return data.(uint64)
}
+func (txn *Txn) getCommittedObjectState(obj *stateObject, slot types.Hash) (types.Hash, error) {
+ if txn.snap != nil {
+ addrHash := obj.addrHash
+ // If the object was destructed in *this* block (and potentially resurrected),
+ // the storage has been cleared out, and we should *not* consult the previous
+ // snapshot about any storage values. The only possible alternatives are:
+ // 1) resurrect happened, and new slot values were set -- those should
+ // have been handled via pendingStorage above.
+ // 2) we don't have new values, and can deliver empty response back
+ if _, destructed := txn.snapDestructs[addrHash]; destructed {
+ return types.Hash{}, nil
+ }
+
+ // query it from cached snapshot
+ if enc, err := txn.snap.Storage(addrHash, crypto.Keccak256Hash(slot.Bytes())); err == nil { // found
+ return utils.StorageBytesToHash(enc)
+ }
+ }
+
+ // If the snapshot is unavailable or reading from it fails, load from the database.
+ return txn.snapshot.GetStorage(obj.address, obj.StorageRoot(), slot)
+}
+
// GetCommittedState returns the state of the address in the trie
-func (txn *Txn) GetCommittedState(addr types.Address, key types.Hash) (types.Hash, error) {
+//
+// The state is committed (persisted, too).
+func (txn *Txn) GetCommittedState(addr types.Address, slot types.Hash) (types.Hash, error) {
+ // If the snapshot is unavailable or reading from it fails, load from the database.
obj, ok := txn.getStateObject(addr)
if !ok {
return types.Hash{}, nil
}
- return txn.snapshot.GetStorage(addr, obj.Account.Root, key)
+ return txn.getCommittedObjectState(obj, slot)
}
func (txn *Txn) TouchAccount(addr types.Address) {
- txn.upsertAccount(addr, true, func(obj *StateObject) {
+ txn.upsertAccount(addr, true, func(obj *stateObject) {
})
}
-// TODO, check panics with this ones
-
func (txn *Txn) Exist(addr types.Address) bool {
_, exists := txn.getStateObject(addr)
@@ -500,42 +651,50 @@ func (txn *Txn) Empty(addr types.Address) bool {
return obj.Empty()
}
-func newStateObject(txn *Txn) *StateObject {
- return &StateObject{
- Account: &Account{
- Balance: big.NewInt(0),
- CodeHash: emptyCodeHash,
- Root: emptyStateHash,
- },
+func (txn *Txn) CreateAccount(addr types.Address) {
+ // prev might have been deleted
+ prev := txn.getDeletedStateObject(addr)
+
+ // cache reset change
+ var prevdestruct bool
+
+ if txn.snap != nil && prev != nil {
+ // destruct object when already deleted
+ prevAddrHash := prev.AddressHash()
+ _, prevdestruct = txn.snapDestructs[prevAddrHash]
+
+ if !prevdestruct {
+ txn.snapDestructs[prevAddrHash] = struct{}{}
+ }
}
-}
-func (txn *Txn) CreateAccount(addr types.Address) {
- obj := &StateObject{
- Account: &Account{
- Balance: big.NewInt(0),
- CodeHash: emptyCodeHash,
- Root: emptyStateHash,
- },
+ // create a new object whether or not one already exists
+ obj := newStateObject(txn, addr, nil)
+
+ if prev != nil { // journal reset status
+ txn.journal.append(resetObjectChange{prev: prev, prevdestruct: prevdestruct})
}
- prev, ok := txn.getStateObject(addr)
- if ok {
- obj.Account.Balance.SetBytes(prev.Account.Balance.Bytes())
+ if prev != nil && !prev.deleted {
+ obj.SetBalance(prev.Balance())
}
+ // insert it to itrie
txn.txn.Insert(addr.Bytes(), obj)
}
+// CleanDeleteObjects clears deleted objects and invalidates their journals.
+//
+// Byzantium fork is always on, so reverting across transactions is not allowed.
func (txn *Txn) CleanDeleteObjects(deleteEmptyObjects bool) {
remove := [][]byte{}
txn.txn.Root().Walk(func(k []byte, v interface{}) bool {
- a, ok := v.(*StateObject)
+ a, ok := v.(*stateObject)
if !ok {
return false
}
- if a.Suicide || a.Empty() && deleteEmptyObjects {
+ if a.suicide || a.Empty() && deleteEmptyObjects {
remove = append(remove, k)
}
@@ -548,57 +707,100 @@ func (txn *Txn) CleanDeleteObjects(deleteEmptyObjects bool) {
panic("it should not happen")
}
- obj, ok := v.(*StateObject)
+ obj, ok := v.(*stateObject)
if !ok {
panic("it should not happen")
}
obj2 := obj.Copy()
- obj2.Deleted = true
+ obj2.deleted = true
txn.txn.Insert(k, obj2)
}
// delete refunds
txn.txn.Delete(refundIndex)
+
+ // Invalidate journal because reverting across transactions is not allowed.
+ txn.clearJournal()
}
// func (txn *Txn) Commit(deleteEmptyObjects bool) (Snapshot, []byte) {
-func (txn *Txn) Commit(deleteEmptyObjects bool) []*Object {
+func (txn *Txn) Commit(deleteEmptyObjects bool) []*stypes.Object {
txn.CleanDeleteObjects(deleteEmptyObjects)
x := txn.txn.Commit()
// Do a more complex thing for now
- objs := []*Object{}
+ objs := []*stypes.Object{}
x.Root().Walk(func(k []byte, v interface{}) bool {
- a, ok := v.(*StateObject)
+ sobj, ok := v.(*stateObject)
if !ok {
// We also have logs, avoid those
return false
}
- obj := &Object{
- Nonce: a.Account.Nonce,
- Address: types.BytesToAddress(k),
- Balance: a.Account.Balance,
- Root: a.Account.Root,
- CodeHash: types.BytesToHash(a.Account.CodeHash),
- DirtyCode: a.DirtyCode,
- Code: a.Code,
+ addr := types.BytesToAddress(k)
+
+ // for storage value marshaling
+ storeAr := fastrlp.DefaultArenaPool.Get()
+ defer fastrlp.DefaultArenaPool.Put(storeAr)
+
+ obj := &stypes.Object{
+ Nonce: sobj.Nonce(),
+ Address: addr,
+ Balance: sobj.Balance(),
+ Root: sobj.StorageRoot(),
+ CodeHash: types.BytesToHash(sobj.CodeHash()),
+ DirtyCode: sobj.dirtyCode,
+ Code: sobj.Code(),
}
- if a.Deleted {
+ if sobj.deleted {
obj.Deleted = true
+
+ // If state snapshotting is active, also mark the destruction there.
+ // Note, we can't do this only at the end of a block because multiple
+ // transactions within the same block might self destruct and then
+ // resurrect an account; but the snapshotter needs both events.
+ if txn.snap != nil {
+ addrHash := sobj.AddressHash()
+ // We need to maintain account deletions explicitly (will remain set indefinitely)
+ txn.snapDestructs[addrHash] = struct{}{}
+ // Clear out any previously updated data (may be recreated via a resurrect)
+ delete(txn.snapAccounts, addrHash)
+ delete(txn.snapStorage, addrHash)
+ }
} else {
- if a.Txn != nil {
- a.Txn.Root().Walk(func(k []byte, v interface{}) bool {
- store := &StorageObject{Key: k}
+ if sobj.radixTxn != nil { // if it has a trie, we need to iterate it
+ sobj.radixTxn.Root().Walk(func(k []byte, v interface{}) bool {
+ store := &stypes.StorageObject{Key: k}
if v == nil {
store.Deleted = true
} else {
- store.Val = v.([]byte) //nolint:forcetypeassert
+ // rlp marshal value here, since snapshot use the same encoding rule.
+ //nolint:forcetypeassert
+ vv := storeAr.NewBytes(bytes.TrimLeft(v.([]byte), "\x00"))
+ store.Val = vv.MarshalTo(nil)
+ }
+
+ // update snapshots storage value
+ if txn.snap != nil {
+ var (
+ // current key is slot, we need slot hash
+ storeHash = crypto.Keccak256Hash(k)
+ storage map[types.Hash][]byte
+ addrHash = sobj.AddressHash()
+ )
+ // create map when not exists
+ if storage = txn.snapStorage[addrHash]; storage == nil {
+ storage = make(map[types.Hash][]byte)
+ txn.snapStorage[addrHash] = storage
+ }
+ // update value. v will be nil if it's deleted
+ storage[storeHash] = store.Val
}
+
obj.Storage = append(obj.Storage, store)
return false
diff --git a/state/txn_test.go b/state/txn_test.go
index 3a4cdb456e..e181fb7795 100644
--- a/state/txn_test.go
+++ b/state/txn_test.go
@@ -5,6 +5,7 @@ import (
"math/big"
"testing"
+ "github.com/dogechain-lab/dogechain/state/stypes"
"github.com/dogechain-lab/dogechain/types"
"github.com/stretchr/testify/assert"
)
@@ -27,13 +28,13 @@ func (m *mockSnapshot) GetStorage(addr types.Address, root types.Hash, key types
return res, nil
}
-func (m *mockSnapshot) GetAccount(addr types.Address) (*Account, error) {
+func (m *mockSnapshot) GetAccount(addr types.Address) (*stypes.Account, error) {
raw, ok := m.state[addr]
if !ok {
return nil, fmt.Errorf("account not found")
}
- acct := &Account{
+ acct := &stypes.Account{
Balance: new(big.Int).SetUint64(raw.Balance),
Nonce: raw.Nonce,
}
diff --git a/state/utils/storage_unmarshaling.go b/state/utils/storage_unmarshaling.go
new file mode 100644
index 0000000000..2389fe8af8
--- /dev/null
+++ b/state/utils/storage_unmarshaling.go
@@ -0,0 +1,27 @@
+package utils
+
+import (
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+func StorageBytesToHash(v []byte) (types.Hash, error) {
+ if len(v) == 0 {
+ return types.Hash{}, nil
+ }
+
+ vv, err := types.RlpUnmarshal(v)
+ if err != nil {
+ return types.Hash{}, err
+ }
+
+ if vv == nil {
+ return types.Hash{}, nil
+ }
+
+ res := []byte{}
+ if res, err = vv.GetBytes(res[:0]); err != nil {
+ return types.Hash{}, err
+ }
+
+ return types.BytesToHash(res), nil
+}
diff --git a/state/utils/storage_unmarshaling_test.go b/state/utils/storage_unmarshaling_test.go
new file mode 100644
index 0000000000..09c6f2ed20
--- /dev/null
+++ b/state/utils/storage_unmarshaling_test.go
@@ -0,0 +1,27 @@
+package utils
+
+import (
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestStorageBytesToHash(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ val []byte
+ want types.Hash
+ }{
+ {nil, types.Hash{}},
+ {[]byte{0x24}, types.StringToHash("0x24")},
+ {[]byte{0x82, 0x01, 0x02}, types.StringToHash("0x0102")},
+ }
+
+ for _, tt := range tests {
+ actual, err := StorageBytesToHash(tt.val)
+ assert.NoError(t, err)
+ assert.Equal(t, tt.want, actual)
+ }
+}
diff --git a/tests/evm_test.go b/tests/evm_test.go
index efbb1bdb4f..7b83c73e25 100644
--- a/tests/evm_test.go
+++ b/tests/evm_test.go
@@ -51,12 +51,12 @@ func testVMCase(t *testing.T, name string, c *VMCase) {
env.GasPrice = types.BytesToHash(c.Exec.GasPrice.Bytes())
env.Origin = c.Exec.Origin
- s, _, root, err := buildState(c.Pre)
+ _, s, _, root, err := buildState(c.Pre)
assert.NoError(t, err)
config := mainnetChainConfig.Forks.At(uint64(env.Number))
- executor := state.NewExecutor(&mainnetChainConfig, s, hclog.NewNullLogger())
+ executor := state.NewExecutor(&mainnetChainConfig, hclog.NewNullLogger(), s)
executor.GetHash = func(*types.Header) func(i uint64) types.Hash {
return vmTestBlockHash
}
@@ -140,6 +140,8 @@ func rlpHashLogs(logs []*types.Log) (res types.Hash) {
}
func TestEVM(t *testing.T) {
+ t.Parallel()
+
folders, err := listFolders(vmTests)
if err != nil {
t.Fatal(err)
@@ -152,13 +154,17 @@ func TestEVM(t *testing.T) {
}
for _, folder := range folders {
+ folder := folder
files, err := listFiles(folder)
if err != nil {
t.Fatal(err)
}
for _, file := range files {
+ file := file
t.Run(file, func(t *testing.T) {
+ t.Parallel()
+
if !strings.HasSuffix(file, ".json") {
return
}
@@ -176,7 +182,6 @@ func TestEVM(t *testing.T) {
for name, cc := range vmcases {
if contains(long, name) && testing.Short() {
t.Skip()
-
continue
}
testVMCase(t, name, cc)
@@ -187,5 +192,144 @@ func TestEVM(t *testing.T) {
}
func vmTestBlockHash(n uint64) types.Hash {
- return types.BytesToHash(crypto.Keccak256([]byte(big.NewInt(int64(n)).String())))
+ return crypto.Keccak256Hash([]byte(big.NewInt(int64(n)).String()))
+}
+
+func TestEVMWithSnapshot(t *testing.T) {
+ t.Parallel()
+
+ folders, err := listFolders(vmTests)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ long := []string{
+ "loop-",
+ "gasprice",
+ "origin",
+ }
+
+ for _, folder := range folders {
+ files, err := listFiles(folder)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, file := range files {
+ file := file
+ t.Run(folder, func(t *testing.T) {
+ t.Parallel()
+
+ if !strings.HasSuffix(file, ".json") {
+ return
+ }
+
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var vmcases map[string]*VMCase
+ if err := json.Unmarshal(data, &vmcases); err != nil {
+ t.Fatal(err)
+ }
+
+ for name, cc := range vmcases {
+ if contains(long, name) && testing.Short() {
+ t.Skip()
+ continue
+ }
+ testVMCaseWithSnapshot(t, name, cc)
+ }
+ })
+ }
+ }
+}
+
+func testVMCaseWithSnapshot(t *testing.T, name string, c *VMCase) {
+ t.Helper()
+
+ env := c.Env.ToEnv(t)
+ env.GasPrice = types.BytesToHash(c.Exec.GasPrice.Bytes())
+ env.Origin = c.Exec.Origin
+
+ triedb, s, _, root, err := buildState(c.Pre)
+ assert.NoError(t, err)
+
+ snaps, err := buildSnapshotTree(triedb, root)
+ assert.NoError(t, err)
+
+ config := mainnetChainConfig.Forks.At(uint64(env.Number))
+
+ executor := state.NewExecutor(&mainnetChainConfig, hclog.NewNullLogger(), s)
+ executor.GetHash = func(*types.Header) func(i uint64) types.Hash {
+ return vmTestBlockHash
+ }
+
+ // set executor snapshot to test state features
+ executor.SetSnaps(snaps)
+
+ e, _ := executor.BeginTxn(root, c.Env.ToHeader(t), env.Coinbase)
+ ctx := e.ContextPtr()
+ ctx.GasPrice = types.BytesToHash(env.GasPrice.Bytes())
+ ctx.Origin = env.Origin
+
+ evmR := evm.NewEVM()
+
+ code := e.GetCode(c.Exec.Address)
+ contract := runtime.NewContractCall(
+ 1,
+ c.Exec.Caller,
+ c.Exec.Caller,
+ c.Exec.Address,
+ c.Exec.Value,
+ c.Exec.GasLimit,
+ code,
+ c.Exec.Data,
+ )
+
+ result := evmR.Run(contract, e, &config)
+
+ if c.Gas == "" {
+ if result.Succeeded() {
+ t.Fatalf("gas unspecified (indicating an error), but VM returned no error")
+ }
+
+ if result.GasLeft > 0 {
+ t.Fatalf("gas unspecified (indicating an error), but VM returned gas remaining > 0")
+ }
+
+ return
+ }
+
+ // check return
+ if c.Out == "" {
+ c.Out = "0x"
+ }
+
+ if ret := hex.EncodeToHex(result.ReturnValue); ret != c.Out {
+ t.Fatalf("return mismatch: got %s, want %s", ret, c.Out)
+ }
+
+ txn := e.Txn()
+
+ // check logs
+ if logs := rlpHashLogs(txn.Logs()); logs != types.StringToHash(c.Logs) {
+ t.Fatalf("logs hash mismatch: got %x, want %x", logs, c.Logs)
+ }
+
+ // check state
+ for addr, alloc := range c.Post {
+ for key, val := range alloc.Storage {
+ if have, err := txn.GetState(addr, key); err != nil || have != val {
+ t.Fatalf("wrong storage value at %s:\n got %s\n want %s\n at address %s\n err: %v",
+ key, have, val, addr, err)
+ }
+ }
+ }
+
+ // check remaining gas
+ if expected := stringToUint64T(t, c.Gas); result.GasLeft != expected {
+ t.Fatalf("gas left mismatch: got %d want %d", result.GasLeft, expected)
+ }
}
diff --git a/tests/state_test.go b/tests/state_test.go
index 4b2b3f7b41..4340b4034f 100644
--- a/tests/state_test.go
+++ b/tests/state_test.go
@@ -9,6 +9,7 @@ import (
"testing"
"github.com/dogechain-lab/dogechain/chain"
+ "github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/helper/hex"
"github.com/dogechain-lab/dogechain/state"
"github.com/dogechain-lab/dogechain/state/runtime/evm"
@@ -27,7 +28,7 @@ type stateCase struct {
Info *info `json:"_info"`
Env *env `json:"env"`
Pre map[types.Address]*chain.GenesisAccount `json:"pre"`
- Post map[string]postState `json:"post"`
+ Post map[string]*postState `json:"post"`
Transaction *stTransaction `json:"transaction"`
}
@@ -48,12 +49,12 @@ func RunSpecificTest(t *testing.T, file string, c stateCase, name, fork string,
t.Fatal(err)
}
- s, snapshot, pastRoot, err := buildState(c.Pre)
+ _, s, snapshot, pastRoot, err := buildState(c.Pre)
assert.NoError(t, err)
forks := config.At(uint64(env.Number))
- xxx := state.NewExecutor(&chain.Params{Forks: config, ChainID: 1}, s, hclog.NewNullLogger())
+ xxx := state.NewExecutor(&chain.Params{Forks: config, ChainID: 1}, hclog.NewNullLogger(), s)
xxx.SetRuntime(precompiled.NewPrecompiled())
xxx.SetRuntime(evm.NewEVM())
@@ -105,7 +106,206 @@ func RunSpecificTest(t *testing.T, file string, c stateCase, name, fork string,
}
}
+func RunSpecificTestWithSnapshot(t *testing.T, file string, c *stateCase, name, fork string, index int, p *postEntry) {
+ t.Helper()
+
+ config, ok := Forks[fork]
+ if !ok {
+ t.Fatalf("config %s not found", fork)
+ }
+
+ env := c.Env.ToEnv(t)
+
+ msg, err := c.Transaction.At(p.Indexes)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ triedb, s, snapshot, pastRoot, err := buildState(c.Pre)
+ assert.NoError(t, err)
+
+ // _, err = buildSnapshotTree(triedb, pastRoot)
+ snaps, err := buildSnapshotTree(triedb, pastRoot)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ forks := config.At(uint64(env.Number))
+
+ xxx := state.NewExecutor(&chain.Params{Forks: config, ChainID: 1}, hclog.NewNullLogger(), s)
+ xxx.SetRuntime(precompiled.NewPrecompiled())
+ xxx.SetRuntime(evm.NewEVM())
+
+ // set executor snapshot to test state features
+ xxx.SetSnaps(snaps)
+
+ xxx.PostHook = func(t *state.Transition) {
+ if name == "failed_tx_xcf416c53" {
+ // create the account
+ t.Txn().TouchAccount(ripemd)
+ // now remove it
+ t.Txn().Suicide(ripemd)
+ }
+ }
+ xxx.GetHash = func(*types.Header) func(i uint64) types.Hash {
+ return vmTestBlockHash
+ }
+
+ executor, _ := xxx.BeginTxn(pastRoot, c.Env.ToHeader(t), env.Coinbase)
+ executor.Apply(msg) //nolint:errcheck
+
+ txn := executor.Txn()
+
+ // mining rewards
+ txn.AddSealingReward(env.Coinbase, big.NewInt(0))
+
+ objs := txn.Commit(forks.EIP155)
+ _, root, err := snapshot.Commit(objs)
+ assert.NoError(t, err)
+
+ if !bytes.Equal(root, p.Root.Bytes()) {
+ t.Fatalf(
+ "root mismatch (%s %s %s %d): expected %s but found %s",
+ file,
+ name,
+ fork,
+ index,
+ p.Root,
+ hex.EncodeToHex(root),
+ )
+ }
+
+ // check post logs
+ if logs := rlpHashLogs(txn.Logs()); logs != p.Logs {
+ t.Fatalf(
+ "logs mismatch (%s, %s %d): expected %s but found %s",
+ name,
+ fork,
+ index,
+ p.Logs.String(),
+ logs.String(),
+ )
+ }
+
+ // update snapshot before checking
+ executor.UpdateSnapshot(types.BytesToHash(root), objs)
+
+ // check pre snapshot
+ parentSnap := snaps.Snapshot(pastRoot)
+ if parentSnap == nil {
+ t.Fatalf("parent snapshot(%s) not generated", pastRoot)
+ }
+
+ for addr, account := range c.Pre {
+ addrhash := crypto.Keccak256Hash(addr.Bytes())
+
+ snapAccount, err := parentSnap.Account(addrhash)
+ if err != nil {
+ t.Fatalf("parent snapshot account(%s) unmarshal failed: %v", addr, err)
+ }
+
+ if account.Balance.Cmp(snapAccount.Balance) != 0 {
+ t.Fatalf(
+ "parent snapshot account(%s) balance not right, want(%s), got(%s)",
+ addr,
+ account.Balance,
+ snapAccount.Balance,
+ )
+ }
+ if len(account.Code) > 0 {
+ codeHash := crypto.Keccak256Hash(account.Code)
+ if codeHash != types.BytesToHash(snapAccount.CodeHash) {
+ t.Fatalf(
+ "parent snapshot account(%s) codehash not right, want(%s), got(%s)",
+ addr,
+ codeHash,
+ hex.EncodeToString(snapAccount.CodeHash),
+ )
+ }
+ }
+ if account.Nonce != snapAccount.Nonce {
+ t.Fatalf(
+ "parent snapshot account(%s) nonce not right, want(%d), got(%d)",
+ addr,
+ account.Nonce,
+ snapAccount.Nonce,
+ )
+ }
+
+ // storage
+ for k, v := range account.Storage {
+ // query storage no matter exists or not
+ sv, err := parentSnap.Storage(addrhash, crypto.Keccak256Hash(k.Bytes()))
+ if err != nil {
+ t.Fatalf(
+ "parent snapshot account(%s) storage(%s) getting failed: %v",
+ addr,
+ k,
+ err,
+ )
+ }
+ // empty hash is "deleted"
+ if v == (types.Hash{}) && len(sv) == 0 {
+ continue
+ }
+ // rlp unmarshal
+ fv, err := types.RlpUnmarshal(sv)
+ if err != nil {
+ t.Fatalf(
+ "parent snapshot account(%s) storage(%s) unmarshal failed: %v",
+ addr,
+ k,
+ err,
+ )
+ }
+ // fastrlp value
+ vv, err := fv.Bytes()
+ if err != nil {
+ t.Fatalf(
+ "parent snapshot account(%s) storage(%s) fastrlp failed: %v",
+ addr,
+ k,
+ err,
+ )
+ }
+ // hash
+ hv := types.BytesToHash(vv)
+ if v != hv {
+ t.Fatalf(
+ "parent snapshot account(%s) storage(%s) not right, want(%s), got(%s)",
+ addr,
+ k,
+ v,
+ hv,
+ )
+ }
+ }
+ }
+
+ // check post snapshot
+ snap := snaps.Snapshot(p.Root)
+ if snap == nil {
+ t.Fatalf("snapshot(%s) not generated", p.Root)
+ }
+
+ // check snapshot from account
+ from, err := snap.Account(crypto.Keccak256Hash(msg.From.Bytes()))
+ if err != nil {
+ t.Fatalf("snapshot account(%s) unmarshal failed: %v", msg.From, err)
+ }
+ if from.Nonce != msg.Nonce+1 {
+ t.Fatalf(
+ "snapshot account(%s) nonce not right, want(%d), got(%d)",
+ msg.From,
+ msg.Nonce,
+ from.Nonce,
+ )
+ }
+}
+
func TestState(t *testing.T) {
+ t.Parallel()
+
long := []string{
"static_Call50000",
"static_Return50000",
@@ -126,27 +326,30 @@ func TestState(t *testing.T) {
}
for _, folder := range folders {
- t.Run(folder, func(t *testing.T) {
- files, err := listFiles(folder)
- if err != nil {
- t.Fatal(err)
- }
+ files, err := listFiles(folder)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, file := range files {
+ file := file
+ t.Run(file, func(t *testing.T) {
+ t.Parallel()
- for _, file := range files {
if !strings.HasSuffix(file, ".json") {
- continue
+ return
}
if contains(long, file) && testing.Short() {
t.Skipf("Long tests are skipped in short mode")
- continue
+ return
}
if contains(skip, file) {
t.Skip()
- continue
+ return
}
data, err := ioutil.ReadFile(file)
@@ -154,19 +357,90 @@ func TestState(t *testing.T) {
t.Fatal(err)
}
- var c map[string]stateCase
+ var c map[string]*stateCase
if err := json.Unmarshal(data, &c); err != nil {
t.Fatal(err)
}
for name, i := range c {
for fork, f := range i.Post {
- for indx, e := range f {
- RunSpecificTest(t, file, i, name, fork, indx, e)
+ for indx, e := range *f {
+ RunSpecificTest(t, file, *i, name, fork, indx, *e)
}
}
}
- }
- })
+ })
+ }
+ }
+}
+
+func TestStateWithSnapshot(t *testing.T) {
+ t.Parallel()
+
+ long := []string{
+ "static_Call50000",
+ "static_Return50000",
+ "static_Call1MB",
+ "stQuadraticComplexityTest",
+ "stTimeConsuming",
+ }
+
+ skip := []string{
+ "RevertPrecompiledTouch",
+ }
+
+ // There are two folders in spec tests, one for the current tests for the Istanbul fork
+ // and one for the legacy tests for the other forks
+ folders, err := listFolders(stateTests, legacyStateTests)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, folder := range folders {
+ files, err := listFiles(folder)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for _, file := range files {
+ file := file
+ t.Run(file, func(t *testing.T) {
+ t.Parallel()
+
+ if !strings.HasSuffix(file, ".json") {
+ return
+ }
+
+ if contains(long, file) && testing.Short() {
+ t.Skipf("Long tests are skipped in short mode")
+
+ return
+ }
+
+ if contains(skip, file) {
+ t.Skip()
+
+ return
+ }
+
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var c map[string]*stateCase
+ if err := json.Unmarshal(data, &c); err != nil {
+ t.Fatal(err)
+ }
+
+ for name, i := range c {
+ for fork, f := range i.Post {
+ for indx, e := range *f {
+ RunSpecificTestWithSnapshot(t, file, i, name, fork, indx, e)
+ }
+ }
+ }
+ })
+ }
}
}
diff --git a/tests/testing.go b/tests/testing.go
index 2da3a47b74..650a9f07f4 100644
--- a/tests/testing.go
+++ b/tests/testing.go
@@ -13,9 +13,12 @@ import (
"github.com/dogechain-lab/dogechain/chain"
"github.com/dogechain-lab/dogechain/crypto"
"github.com/dogechain-lab/dogechain/helper/hex"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
"github.com/dogechain-lab/dogechain/state"
itrie "github.com/dogechain-lab/dogechain/state/immutable-trie"
"github.com/dogechain-lab/dogechain/state/runtime"
+ "github.com/dogechain-lab/dogechain/state/snapshot"
+ "github.com/dogechain-lab/dogechain/trie"
"github.com/dogechain-lab/dogechain/types"
"github.com/hashicorp/go-hclog"
)
@@ -224,8 +227,9 @@ func (e *exec) UnmarshalJSON(input []byte) error {
func buildState(
allocs map[types.Address]*chain.GenesisAccount,
-) (state.State, state.Snapshot, types.Hash, error) {
- s := itrie.NewStateDB(itrie.NewMemoryStorage(), hclog.NewNullLogger(), nil)
+) (itrie.Storage, state.State, state.Snapshot, types.Hash, error) {
+ triedb := memorydb.New()
+ s := itrie.NewStateDB(triedb, hclog.NewNullLogger(), nil)
snap := s.NewSnapshot()
txn := state.NewTxn(snap)
@@ -248,10 +252,31 @@ func buildState(
snap, root, err := snap.Commit(objs)
if err != nil {
- return nil, nil, types.ZeroHash, err
+ return nil, nil, nil, types.Hash{}, err
}
- return s, snap, types.BytesToHash(root), nil
+ return triedb, s, snap, types.BytesToHash(root), nil
+}
+
+func buildSnapshotTree(triedb itrie.Storage, baseRoot types.Hash) (*snapshot.Tree, error) {
+ snapCfg := snapshot.Config{
+ CacheSize: 1, // enough disk cache for test
+ Recovery: false,
+ NoBuild: false,
+ AsyncBuild: false, // build it from start and wait
+ }
+
+ logger := hclog.NewNullLogger()
+ snpTrieDB := trie.NewDatabaseWithConfig(
+ triedb,
+ &trie.Config{
 Cache: 0, // no cleans cache needed for tests
+ Journal: "", // empty journal
+ },
+ logger,
+ )
+
+ return snapshot.New(snapCfg, triedb, snpTrieDB, baseRoot, logger, snapshot.NilMetrics())
}
type indexes struct {
@@ -266,7 +291,7 @@ type postEntry struct {
Indexes indexes
}
-type postState []postEntry
+type postState []*postEntry
func (p *postEntry) UnmarshalJSON(input []byte) error {
type stateUnmarshall struct {
diff --git a/trie/committer.go b/trie/committer.go
new file mode 100644
index 0000000000..72bea0d77a
--- /dev/null
+++ b/trie/committer.go
@@ -0,0 +1,239 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// leaf represents a trie leaf node
+type leaf struct {
+ blob []byte // raw blob of leaf
+ parent types.Hash // the hash of parent node
+}
+
+// committer is the tool used for the trie Commit operation. The committer will
+// capture all dirty nodes during the commit process and keep them cached in
+// insertion order.
+type committer struct {
+ nodes *NodeSet
+ tracer *tracer
+ collectLeaf bool
+}
+
+// newCommitter creates a new committer or picks one from the pool.
+func newCommitter(owner types.Hash, tracer *tracer, collectLeaf bool) *committer {
+ return &committer{
+ nodes: NewNodeSet(owner),
+ tracer: tracer,
+ collectLeaf: collectLeaf,
+ }
+}
+
+// Commit collapses a node down into a hash node and returns it along with
+// the modified nodeset.
+func (c *committer) Commit(n node) (hashNode, *NodeSet, error) {
+ h, err := c.commit(nil, n)
+ if err != nil {
+ return nil, nil, err
+ }
+ // Some nodes can be deleted from trie which can't be captured
+ // by committer itself. Iterate all deleted nodes tracked by
+ // tracer and marked them as deleted only if they are present
+ // in database previously.
+ c.tracer.markDeletions(c.nodes)
+
+ //nolint:forcetypeassert
+ return h.(hashNode), c.nodes, nil
+}
+
+// commit collapses a node down into a hash node and returns it.
+func (c *committer) commit(path []byte, n node) (node, error) {
+ // if this path is clean, use available cached data
+ hash, dirty := n.cache()
+ if hash != nil && !dirty {
+ return hash, nil
+ }
+ // Commit children, then parent, and remove the dirty flag.
+ switch cn := n.(type) {
+ case *shortNode:
+ // Commit child
+ collapsed := cn.copy()
+
+ // If the child is fullNode, recursively commit,
+ // otherwise it can only be hashNode or valueNode.
+ if _, ok := cn.Val.(*fullNode); ok {
+ childV, err := c.commit(append(path, cn.Key...), cn.Val)
+ if err != nil {
+ return nil, err
+ }
+
+ collapsed.Val = childV
+ }
+ // The key needs to be copied, since we're adding it to the
+ // modified nodeset.
+ collapsed.Key = hexToCompact(cn.Key)
+
+ hashedNode := c.store(path, collapsed)
+ if hn, ok := hashedNode.(hashNode); ok {
+ return hn, nil
+ }
+ // The short node now is embedded in its parent. Mark the node as
+ // deleted if it's present in database previously. It's equivalent
+ // as deletion from database's perspective.
+ if prev := c.tracer.getPrev(path); len(prev) != 0 {
+ c.nodes.markDeleted(path, prev)
+ }
+
+ return collapsed, nil
+ case *fullNode:
+ hashedKids, err := c.commitChildren(path, cn)
+ if err != nil {
+ return nil, err
+ }
+
+ collapsed := cn.copy()
+ collapsed.Children = hashedKids
+
+ hashedNode := c.store(path, collapsed)
+ if hn, ok := hashedNode.(hashNode); ok {
+ return hn, nil
+ }
+ // The full node now is embedded in its parent. Mark the node as
+ // deleted if it's present in database previously. It's equivalent
+ // as deletion from database's perspective.
+ if prev := c.tracer.getPrev(path); len(prev) != 0 {
+ c.nodes.markDeleted(path, prev)
+ }
+
+ return collapsed, nil
+ case hashNode:
+ return cn, nil
+ default:
+ // nil, valuenode shouldn't be committed
+ panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+ }
+}
+
+// commitChildren commits the children of the given fullnode
+func (c *committer) commitChildren(path []byte, n *fullNode) ([17]node, error) {
+ var children [17]node
+
+ for i := 0; i < 16; i++ {
+ child := n.Children[i]
+ if child == nil {
+ continue
+ }
+ // If it's the hashed child, save the hash value directly.
+ // Note: it's impossible that the child in range [0, 15]
+ // is a valueNode.
+ if hn, ok := child.(hashNode); ok {
+ children[i] = hn
+
+ continue
+ }
+ // Commit the child recursively and store the "hashed" value.
+ // Note the returned node can be some embedded nodes, so it's
+ // possible the type is not hashNode.
+ hashed, err := c.commit(append(path, byte(i)), child)
+ if err != nil {
+ return children, err
+ }
+
+ children[i] = hashed
+ }
+ // For the 17th child, it's possible the type is valuenode.
+ if n.Children[16] != nil {
+ children[16] = n.Children[16]
+ }
+
+ return children, nil
+}
+
+// store hashes the node n and adds it to the modified nodeset. If leaf collection
+// is enabled, leaf nodes will be tracked in the modified nodeset as well.
+func (c *committer) store(path []byte, n node) node {
+ // Larger nodes are replaced by their hash and stored in the database.
+ var hash, _ = n.cache()
+
+ // This was not generated - must be a small node stored in the parent.
+ // In theory, we should check if the node is leaf here (embedded node
+ // usually is leaf node). But small value (less than 32bytes) is not
+ // our target (leaves in account trie only).
+ if hash == nil {
+ return n
+ }
+ // We have the hash already, estimate the RLP encoding-size of the node.
+ // The size is used for mem tracking, does not need to be exact
+ var (
+ size = estimateSize(n)
+ nhash = types.BytesToHash(hash)
+ mnode = &memoryNode{
+ hash: nhash,
+ node: simplifyNode(n),
+ size: uint16(size),
+ }
+ )
+ // Collect the dirty node to nodeset for return.
+ c.nodes.markUpdated(path, mnode, c.tracer.getPrev(path))
+
+ // Collect the corresponding leaf node if it's required. We don't check
+ // full node since it's impossible to store value in fullNode. The key
+ // length of leaves should be exactly same.
+ if c.collectLeaf {
+ if sn, ok := n.(*shortNode); ok {
+ if val, ok := sn.Val.(valueNode); ok {
+ c.nodes.addLeaf(&leaf{blob: val, parent: nhash})
+ }
+ }
+ }
+
+ return hash
+}
+
+// estimateSize estimates the size of an rlp-encoded node, without actually
+// rlp-encoding it (zero allocs). This method has been experimentally tried, and with a trie
+// with 1000 leaves, the only errors above 1% are on small shortnodes, where this
+// method overestimates by 2 or 3 bytes (e.g. 37 instead of 35)
+func estimateSize(n node) int {
+ switch n := n.(type) {
+ case *shortNode:
+ // A short node contains a compacted key, and a value.
+ return 3 + len(n.Key) + estimateSize(n.Val)
+ case *fullNode:
+ // A full node contains up to 16 hashes (some nils), and a key
+ s := 3
+
+ for i := 0; i < 16; i++ {
+ if child := n.Children[i]; child != nil {
+ s += estimateSize(child)
+ } else {
+ s++
+ }
+ }
+
+ return s
+ case valueNode:
+ return 1 + len(n)
+ case hashNode:
+ return 1 + len(n)
+ default:
+ panic(fmt.Sprintf("node type %T", n))
+ }
+}
diff --git a/trie/database.go b/trie/database.go
new file mode 100644
index 0000000000..09d7b9b775
--- /dev/null
+++ b/trie/database.go
@@ -0,0 +1,940 @@
+// Copyright 2018 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "runtime"
+ "sync"
+ "time"
+
+ "github.com/VictoriaMetrics/fastcache"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+//
+// Note, the trie Database is **not** thread safe in its mutations, but it **is**
+// thread safe in providing individual, independent node access. The rationale
+// behind this split design is to provide read access to RPC handlers and sync
+// servers even while the trie is executing expensive garbage collection.
+type Database struct {
+	diskdb kvdb.Database // Persistent storage for matured trie nodes
+
+	cleans *fastcache.Cache // GC friendly memory cache of clean node RLPs
+	dirties map[types.Hash]*cachedNode // Data and references relationships of dirty trie nodes
+	oldest types.Hash // Oldest tracked node, flush-list head
+	newest types.Hash // Newest tracked node, flush-list tail
+
+	gctime time.Duration // Time spent on garbage collection since last commit
+	gcnodes uint64 // Nodes garbage collected since last commit
+	gcsize types.StorageSize // Data storage garbage collected since last commit
+
+	flushtime time.Duration // Time spent on data flushing since last commit
+	flushnodes uint64 // Nodes flushed since last commit
+	flushsize types.StorageSize // Data storage flushed since last commit
+
+	dirtiesSize types.StorageSize // Storage size of the dirty node cache (exc. metadata)
+	childrenSize types.StorageSize // Storage size of the external children tracking
+
+	lock sync.RWMutex // Guards dirties, the flush-list ends and the size counters
+	logger Logger // Sink for debug and error reporting
+}
+
+// rawNode is a simple binary blob used to differentiate between collapsed trie
+// nodes and already encoded RLP binary blobs (while at the same time store them
+// in the same cache fields).
+type rawNode []byte
+
+// cache and fstring exist only to satisfy the node interface; raw nodes must
+// never appear in a live (expanded) trie, hence the panics.
+func (n rawNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+// EncodeRLP writes the blob out verbatim — it is already RLP encoded.
+func (n rawNode) EncodeRLP(w io.Writer) error {
+	_, err := w.Write(n)
+
+	return err
+}
+
+// rawFullNode represents only the useful data content of a full node, with the
+// caches and flags stripped out to minimize its data storage. This type honors
+// the same RLP encoding as the original parent.
+type rawFullNode [17]node
+
+func (n rawFullNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawFullNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+// EncodeRLP streams the node's children through an encoder buffer.
+func (n rawFullNode) EncodeRLP(w io.Writer) error {
+	eb := rlp.NewEncoderBuffer(w)
+	n.encode(eb)
+
+	return eb.Flush()
+}
+
+// rawShortNode represents only the useful data content of a short node, with the
+// caches and flags stripped out to minimize its data storage. This type honors
+// the same RLP encoding as the original parent.
+type rawShortNode struct {
+	Key []byte
+	Val node
+}
+
+func (n rawShortNode) cache() (hashNode, bool) { panic("this should never end up in a live trie") }
+func (n rawShortNode) fstring(ind string) string { panic("this should never end up in a live trie") }
+
+// cachedNode is all the information we know about a single cached trie node
+// in the memory database write layer.
+//
+// Entries form a doubly linked flush-list (via flushPrev/flushNext) walked by
+// Cap from Database.oldest towards Database.newest to evict in insertion order.
+type cachedNode struct {
+	node node // Cached collapsed trie node, or raw rlp data
+	size uint16 // Byte size of the useful cached data
+
+	parents uint32 // Number of live nodes referencing this one
+	children map[types.Hash]uint16 // External children referenced by this node
+
+	flushPrev types.Hash // Previous node in the flush-list
+	flushNext types.Hash // Next node in the flush-list
+}
+
+// cachedNodeSize is the raw size of a cachedNode data structure without any
+// node data included. It's an approximate size, but should be a lot better
+// than not counting them.
+var cachedNodeSize = int(reflect.TypeOf(cachedNode{}).Size())
+
+// cachedNodeChildrenSize is the raw size of an initialized but empty external
+// reference map.
+const cachedNodeChildrenSize = 48
+
+// rlp returns the raw rlp encoded blob of the cached trie node, either directly
+// from the cache, or by regenerating it from the collapsed node.
+func (n *cachedNode) rlp() []byte {
+	// Already-encoded blobs are returned as-is, everything else is re-encoded.
+	if node, ok := n.node.(rawNode); ok {
+		return node
+	}
+
+	return nodeToBytes(n.node)
+}
+
+// obj returns the decoded and expanded trie node, either directly from the cache,
+// or by regenerating it from the rlp encoded blob.
+func (n *cachedNode) obj(hash types.Hash) node {
+	if node, ok := n.node.(rawNode); ok {
+		// The raw-blob format nodes are loaded either from the
+		// clean cache or the database, they are all in their own
+		// copy and safe to use unsafe decoder.
+		return mustDecodeNodeUnsafe(hash[:], node)
+	}
+
+	return expandNode(hash[:], n.node)
+}
+
+// forChilds invokes the callback for all the tracked children of this node,
+// both the implicit ones from inside the node as well as the explicit ones
+// from outside the node.
+func (n *cachedNode) forChilds(onChild func(hash types.Hash)) {
+	// Explicit external references first (e.g. storage trie roots).
+	for child := range n.children {
+		onChild(child)
+	}
+
+	// Raw blobs carry no decodable children, everything else is traversed.
+	if _, ok := n.node.(rawNode); !ok {
+		forGatherChildren(n.node, onChild)
+	}
+}
+
+// forGatherChildren traverses the node hierarchy of a collapsed storage node and
+// invokes the callback for all the hashnode children.
+func forGatherChildren(n node, onChild func(hash types.Hash)) {
+	switch n := n.(type) {
+	case *rawShortNode:
+		forGatherChildren(n.Val, onChild)
+	case rawFullNode:
+		for i := 0; i < 16; i++ {
+			forGatherChildren(n[i], onChild)
+		}
+	case hashNode:
+		onChild(types.BytesToHash(n))
+	case valueNode, nil, rawNode:
+		// No hash children to report for leaves, empty slots or raw blobs.
+	default:
+		panic(fmt.Sprintf("unknown node type: %T", n))
+	}
+}
+
+// simplifyNode traverses the hierarchy of an expanded memory node and discards
+// all the internal caches, returning a node that only contains the raw data.
+func simplifyNode(n node) node {
+	switch n := n.(type) {
+	case *shortNode:
+		// Short nodes discard the flags and cascade
+		return &rawShortNode{Key: n.Key, Val: simplifyNode(n.Val)}
+	case *fullNode:
+		// Full nodes discard the flags and cascade
+		// (the array conversion copies Children, so the original is untouched)
+		node := rawFullNode(n.Children)
+
+		for i := 0; i < len(node); i++ {
+			if node[i] != nil {
+				node[i] = simplifyNode(node[i])
+			}
+		}
+
+		return node
+	case valueNode, hashNode, rawNode:
+		// Already minimal, nothing to strip.
+		return n
+	default:
+		panic(fmt.Sprintf("unknown node type: %T", n))
+	}
+}
+
+// expandNode traverses the node hierarchy of a collapsed storage node and converts
+// all fields and keys into expanded memory form.
+//
+// The supplied hash is attached only to the root of the expansion; children are
+// expanded with a nil hash since their hashes are not known here.
+func expandNode(hash hashNode, n node) node {
+	switch n := n.(type) {
+	case *rawShortNode:
+		// Short nodes need key and child expansion
+		return &shortNode{
+			Key: compactToHex(n.Key),
+			Val: expandNode(nil, n.Val),
+			flags: nodeFlag{
+				hash: hash,
+			},
+		}
+	case rawFullNode:
+		// Full nodes need child expansion
+		node := &fullNode{
+			flags: nodeFlag{
+				hash: hash,
+			},
+		}
+
+		for i := 0; i < len(node.Children); i++ {
+			if n[i] != nil {
+				node.Children[i] = expandNode(nil, n[i])
+			}
+		}
+
+		return node
+	case valueNode, hashNode:
+		return n
+	default:
+		panic(fmt.Sprintf("unknown node type: %T", n))
+	}
+}
+
+// Config defines all necessary options for database.
+type Config struct {
+	Cache int // Memory allowance (MB) to use for caching trie nodes in memory
+	Journal string // Journal of clean cache to survive node restarts
+}
+
+// NewDatabase creates a new trie database to store ephemeral trie content before
+// its written out to disk or garbage collected. No read cache is created, so all
+// data retrievals will hit the underlying disk database.
+//
+// It is shorthand for NewDatabaseWithConfig with a nil config.
+func NewDatabase(diskdb kvdb.Database, logger Logger) *Database {
+	return NewDatabaseWithConfig(diskdb, nil, logger)
+}
+
+// NewDatabaseWithConfig creates a new trie database to store ephemeral trie content
+// before its written out to disk or garbage collected. It also acts as a read cache
+// for nodes loaded from disk.
+func NewDatabaseWithConfig(diskdb kvdb.Database, config *Config, logger Logger) *Database {
+ var cleans *fastcache.Cache
+
+ if config != nil && config.Cache > 0 {
+ if config.Journal == "" {
+ cleans = fastcache.New(config.Cache * 1024 * 1024)
+ } else {
+ cleans = fastcache.LoadFromFileOrNew(config.Journal, config.Cache*1024*1024)
+ }
+ }
+
+ db := &Database{
+ diskdb: diskdb,
+ cleans: cleans,
+ dirties: map[types.Hash]*cachedNode{{}: {
+ children: make(map[types.Hash]uint16),
+ }},
+ logger: logger,
+ }
+
+ return db
+}
+
+// insert inserts a simplified trie node into the memory database.
+// All nodes inserted by this function will be reference tracked
+// and in theory should only used for **trie nodes** insertion.
+//
+// The caller must hold db.lock (see Update), since dirties and the
+// flush-list endpoints are mutated without internal locking.
+func (db *Database) insert(hash types.Hash, size int, node node) {
+	// If the node's already cached, skip
+	if _, ok := db.dirties[hash]; ok {
+		return
+	}
+
+	// Create the cached entry for this node
+	entry := &cachedNode{
+		node: node,
+		size: uint16(size),
+		flushPrev: db.newest,
+	}
+
+	// Bump the reference count of every already-dirty child.
+	entry.forChilds(func(child types.Hash) {
+		if c := db.dirties[child]; c != nil {
+			c.parents++
+		}
+	})
+
+	db.dirties[hash] = entry
+
+	// Update the flush-list endpoints
+	if db.oldest == (types.Hash{}) {
+		db.oldest, db.newest = hash, hash
+	} else {
+		db.dirties[db.newest].flushNext, db.newest = hash, hash
+	}
+
+	db.dirtiesSize += types.StorageSize(types.HashLength + entry.size)
+}
+
+// node retrieves a cached trie node from memory, or returns nil if none can be
+// found in the memory cache.
+//
+// Note: disk errors are swallowed and reported as a plain nil node.
+func (db *Database) node(hash types.Hash) node {
+	// Retrieve the node from the clean cache if available
+	if db.cleans != nil {
+		if enc := db.cleans.Get(nil, hash[:]); enc != nil {
+			// The returned value from cache is in its own copy,
+			// safe to use mustDecodeNodeUnsafe for decoding.
+			return mustDecodeNodeUnsafe(hash[:], enc)
+		}
+	}
+
+	// Retrieve the node from the dirty cache if available
+	db.lock.RLock()
+	dirty := db.dirties[hash]
+	db.lock.RUnlock()
+
+	if dirty != nil {
+		return dirty.obj(hash)
+	}
+
+	// Content unavailable in memory, attempt to retrieve from disk
+	enc, exists, err := db.diskdb.Get(hash.Bytes())
+	if err != nil || !exists || enc == nil {
+		return nil
+	}
+
+	// Populate the clean cache on a disk hit to speed up future reads.
+	if db.cleans != nil {
+		db.cleans.Set(hash[:], enc)
+	}
+
+	// The returned value from database is in its own copy,
+	// safe to use mustDecodeNodeUnsafe for decoding.
+	return mustDecodeNodeUnsafe(hash[:], enc)
+}
+
+// Node retrieves an encoded cached trie node from memory. If it cannot be found
+// cached, the method queries the persistent database for the content.
+func (db *Database) Node(hash types.Hash) ([]byte, error) {
+	// It doesn't make sense to retrieve the metaroot
+	if hash == (types.Hash{}) {
+		return nil, errors.New("not found")
+	}
+
+	// Retrieve the node from the clean cache if available
+	if db.cleans != nil {
+		if enc := db.cleans.Get(nil, hash[:]); enc != nil {
+			return enc, nil
+		}
+	}
+
+	// Retrieve the node from the dirty cache if available
+	db.lock.RLock()
+	dirty := db.dirties[hash]
+	db.lock.RUnlock()
+
+	if dirty != nil {
+		return dirty.rlp(), nil
+	}
+
+	// Content unavailable in memory, attempt to retrieve from disk
+	enc := rawdb.ReadTrieNode(db.diskdb, hash)
+
+	if len(enc) != 0 {
+		// Populate the clean cache so the next read skips the disk.
+		if db.cleans != nil {
+			db.cleans.Set(hash[:], enc)
+		}
+
+		return enc, nil
+	}
+
+	// Not present in any cache nor on disk.
+	return nil, errors.New("not found")
+}
+
+// Nodes retrieves the hashes of all the nodes cached within the memory database.
+// This method is extremely expensive and should only be used to validate internal
+// states in test code.
+func (db *Database) Nodes() []types.Hash {
+ db.lock.RLock()
+ defer db.lock.RUnlock()
+
+ var hashes = make([]types.Hash, 0, len(db.dirties))
+
+ for hash := range db.dirties {
+ if hash != (types.Hash{}) { // Special case for "root" references/nodes
+ hashes = append(hashes, hash)
+ }
+ }
+
+ return hashes
+}
+
+// Reference adds a new reference from a parent node to a child node.
+// This function is used to add reference between internal trie node
+// and external node(e.g. storage trie root), all internal trie nodes
+// are referenced together by database itself.
+func (db *Database) Reference(child types.Hash, parent types.Hash) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+// The caller must hold db.lock.
+func (db *Database) reference(child types.Hash, parent types.Hash) {
+	// If the node does not exist, it's a node pulled from disk, skip
+	node, ok := db.dirties[child]
+	if !ok {
+		return
+	}
+	// If the reference already exists, only duplicate for roots
+	// (the metaroot, zero hash, is allowed to reference a child repeatedly).
+	if db.dirties[parent].children == nil {
+		db.dirties[parent].children = make(map[types.Hash]uint16)
+		db.childrenSize += cachedNodeChildrenSize
+	} else if _, ok = db.dirties[parent].children[child]; ok && parent != (types.Hash{}) {
+		return
+	}
+
+	node.parents++
+	db.dirties[parent].children[child]++
+
+	// Only account for the map entry once, on its first insertion.
+	if db.dirties[parent].children[child] == 1 {
+		db.childrenSize += types.HashLength + 2 // uint16 counter
+	}
+}
+
+// Dereference removes an existing reference from a root node.
+func (db *Database) Dereference(root types.Hash) {
+	// Sanity check to ensure that the meta-root is not removed
+	if root == (types.Hash{}) {
+		db.logger.Error("Attempted to dereference the trie cache meta root")
+
+		return
+	}
+
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	// Snapshot the stats before the cascade so the delta can be reported.
+	nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
+	db.dereference(root, types.Hash{})
+
+	db.gcnodes += uint64(nodes - len(db.dirties))
+	db.gcsize += storage - db.dirtiesSize
+	db.gctime += time.Since(start)
+
+	db.logger.Debug("Dereferenced trie from memory database",
+		"nodes", nodes-len(db.dirties),
+		"size", storage-db.dirtiesSize,
+		"time", time.Since(start),
+		"gcnodes", db.gcnodes,
+		"gcsize", db.gcsize,
+		"gctime", db.gctime,
+		"livenodes", len(db.dirties),
+		"livesize", db.dirtiesSize,
+	)
+}
+
+// dereference is the private locked version of Dereference.
+// The caller must hold db.lock.
+func (db *Database) dereference(child types.Hash, parent types.Hash) {
+	// Dereference the parent-child
+	node := db.dirties[parent]
+
+	if node.children != nil && node.children[child] > 0 {
+		node.children[child]--
+
+		if node.children[child] == 0 {
+			delete(node.children, child)
+
+			db.childrenSize -= (types.HashLength + 2) // uint16 counter
+		}
+	}
+
+	// If the child does not exist, it's a previously committed node.
+	node, ok := db.dirties[child]
+	if !ok {
+		return
+	}
+
+	// If there are no more references to the child, delete it and cascade
+	if node.parents > 0 {
+		// This is a special cornercase where a node loaded from disk (i.e. not in the
+		// memcache any more) gets reinjected as a new node (short node split into full,
+		// then reverted into short), causing a cached node to have no parents. That is
+		// no problem in itself, but don't make maxint parents out of it.
+		node.parents--
+	}
+
+	if node.parents == 0 {
+		// Remove the node from the flush-list, patching the neighbouring
+		// links (or the list endpoints if the node sits at either end).
+		switch child {
+		case db.oldest:
+			db.oldest = node.flushNext
+			db.dirties[node.flushNext].flushPrev = types.Hash{}
+		case db.newest:
+			db.newest = node.flushPrev
+			db.dirties[node.flushPrev].flushNext = types.Hash{}
+		default:
+			db.dirties[node.flushPrev].flushNext = node.flushNext
+			db.dirties[node.flushNext].flushPrev = node.flushPrev
+		}
+		// Dereference all children and delete the node
+		node.forChilds(func(hash types.Hash) {
+			db.dereference(hash, child)
+		})
+
+		delete(db.dirties, child)
+		db.dirtiesSize -= types.StorageSize(types.HashLength + int(node.size))
+
+		if node.children != nil {
+			db.childrenSize -= cachedNodeChildrenSize
+		}
+	}
+}
+
+// Cap iteratively flushes old but still referenced trie nodes until the total
+// memory usage goes below the given threshold.
+//
+// Note, this method is a non-synchronized mutator. It is unsafe to call this
+// concurrently with other mutators.
+func (db *Database) Cap(limit types.StorageSize) error {
+	// Create a database batch to flush persistent data out. It is important that
+	// outside code doesn't see an inconsistent state (referenced data removed from
+	// memory cache during commit but not yet in persistent storage). This is ensured
+	// by only uncaching existing data when the database write finalizes.
+	nodes, storage, start := len(db.dirties), db.dirtiesSize, time.Now()
+	batch := db.diskdb.NewBatch()
+
+	// db.dirtiesSize only contains the useful data in the cache, but when reporting
+	// the total memory consumption, the maintenance metadata is also needed to be
+	// counted. The metaroot entry and its child references are bookkeeping only
+	// and are excluded from the estimate.
+	size := db.dirtiesSize + types.StorageSize((len(db.dirties)-1)*cachedNodeSize)
+	size += db.childrenSize - types.StorageSize(len(db.dirties[types.Hash{}].children)*(types.HashLength+2))
+
+	// Keep committing nodes from the flush-list until we're below allowance
+	oldest := db.oldest
+	for size > limit && oldest != (types.Hash{}) {
+		// Fetch the oldest referenced node and push into the batch
+		node := db.dirties[oldest]
+		rawdb.WriteTrieNode(batch, oldest, node.rlp())
+
+		// If we exceeded the ideal batch size, commit and reset
+		if batch.ValueSize() >= kvdb.IdealBatchSize {
+			if err := batch.Write(); err != nil {
+				db.logger.Error("Failed to write flush list to disk", "err", err)
+
+				return err
+			}
+
+			batch.Reset()
+		}
+
+		// Iterate to the next flush item, or abort if the size cap was achieved. Size
+		// is the total size, including the useful cached data (hash -> blob), the
+		// cache item metadata, as well as external children mappings.
+		size -= types.StorageSize(types.HashLength + int(node.size) + cachedNodeSize)
+
+		if node.children != nil {
+			size -= types.StorageSize(cachedNodeChildrenSize + len(node.children)*(types.HashLength+2))
+		}
+
+		oldest = node.flushNext
+	}
+
+	// Flush out any remainder data from the last batch
+	if err := batch.Write(); err != nil {
+		db.logger.Error("Failed to write flush list to disk", "err", err)
+
+		return err
+	}
+
+	// Write successful, clear out the flushed data. The dirty cache is only
+	// pruned after the disk write succeeded, so readers never lose data.
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	for db.oldest != oldest {
+		node := db.dirties[db.oldest]
+		delete(db.dirties, db.oldest)
+		db.oldest = node.flushNext
+
+		db.dirtiesSize -= types.StorageSize(types.HashLength + int(node.size))
+
+		if node.children != nil {
+			db.childrenSize -= types.StorageSize(cachedNodeChildrenSize + len(node.children)*(types.HashLength+2))
+		}
+	}
+
+	if db.oldest != (types.Hash{}) {
+		db.dirties[db.oldest].flushPrev = types.Hash{}
+	}
+
+	db.flushnodes += uint64(nodes - len(db.dirties))
+	db.flushsize += storage - db.dirtiesSize
+	db.flushtime += time.Since(start)
+
+	db.logger.Debug("Persisted nodes from memory database",
+		"nodes", nodes-len(db.dirties),
+		"size", storage-db.dirtiesSize,
+		"time", time.Since(start),
+		"flushnodes", db.flushnodes,
+		"flushsize", db.flushsize,
+		"flushtime", db.flushtime,
+		"livenodes", len(db.dirties),
+		"livesize", db.dirtiesSize,
+	)
+
+	return nil
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk, forcefully tearing down all references in both directions. As a side
+// effect, all pre-images accumulated up to this point are also written.
+//
+// Note, this method is a non-synchronized mutator. It is unsafe to call this
+// concurrently with other mutators.
+func (db *Database) Commit(node types.Hash, report bool, callback func(types.Hash)) error {
+	// Create a database batch to flush persistent data out. It is important that
+	// outside code doesn't see an inconsistent state (referenced data removed from
+	// memory cache during commit but not yet in persistent storage). This is ensured
+	// by only uncaching existing data when the database write finalizes.
+	start := time.Now()
+	batch := db.diskdb.NewBatch()
+
+	// Move the trie itself into the batch, flushing if enough data is accumulated
+	nodes, storage := len(db.dirties), db.dirtiesSize
+
+	uncacher := &cleaner{db}
+
+	if err := db.commit(node, batch, uncacher, callback); err != nil {
+		db.logger.Error("Failed to commit trie from trie database", "err", err)
+
+		return err
+	}
+
+	// Trie mostly committed to disk, flush any batch leftovers
+	if err := batch.Write(); err != nil {
+		db.logger.Error("Failed to write trie to disk", "err", err)
+
+		return err
+	}
+
+	// Uncache any leftovers in the last batch by replaying the written keys
+	// through the cleaner, which moves them from the dirty to the clean cache.
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	if err := batch.Replay(uncacher); err != nil {
+		return err
+	}
+
+	batch.Reset()
+
+	// Report at Info level only when explicitly requested, Debug otherwise.
+	logger := db.logger.Info
+	if !report {
+		logger = db.logger.Debug
+	}
+
+	logger("Persisted trie from memory database",
+		"nodes", nodes-len(db.dirties)+int(db.flushnodes),
+		"size", storage-db.dirtiesSize+db.flushsize,
+		"time", time.Since(start)+db.flushtime,
+		"gcnodes", db.gcnodes,
+		"gcsize", db.gcsize,
+		"gctime", db.gctime,
+		"livenodes", len(db.dirties),
+		"livesize", db.dirtiesSize,
+	)
+
+	// Reset the garbage collection statistics
+	db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
+	db.flushnodes, db.flushsize, db.flushtime = 0, 0, 0
+
+	return nil
+}
+
+// commit is the private locked version of Commit.
+// It recurses depth-first so children are always written before their parent.
+func (db *Database) commit(hash types.Hash, batch kvdb.Batch, uncacher *cleaner, callback func(types.Hash)) error {
+	// If the node does not exist, it's a previously committed node
+	node, ok := db.dirties[hash]
+	if !ok {
+		return nil
+	}
+
+	var err error
+
+	// Commit all children first, aborting the walk on the first failure.
+	node.forChilds(func(child types.Hash) {
+		if err == nil {
+			err = db.commit(child, batch, uncacher, callback)
+		}
+	})
+
+	if err != nil {
+		return err
+	}
+
+	// If we've reached an optimal batch size, commit and start over
+	rawdb.WriteTrieNode(batch, hash, node.rlp())
+
+	if callback != nil {
+		callback(hash)
+	}
+
+	if batch.ValueSize() >= kvdb.IdealBatchSize {
+		if err := batch.Write(); err != nil {
+			return err
+		}
+
+		// Uncache whatever was just persisted before reusing the batch.
+		db.lock.Lock()
+
+		err := batch.Replay(uncacher)
+
+		batch.Reset()
+		db.lock.Unlock()
+
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// cleaner is a database batch replayer that takes a batch of write operations
+// and cleans up the trie database from anything written to disk.
+type cleaner struct {
+	db *Database
+}
+
+// Set reacts to database writes and implements dirty data uncaching. This is the
+// post-processing step of a commit operation where the already persisted trie is
+// removed from the dirty cache and moved into the clean cache. The reason behind
+// the two-phase commit is to ensure data availability while moving from memory
+// to disk.
+func (c *cleaner) Set(key []byte, rlp []byte) error {
+	hash := types.BytesToHash(key)
+
+	// If the node does not exist, we're done on this path
+	node, ok := c.db.dirties[hash]
+	if !ok {
+		return nil
+	}
+
+	// Node still exists, remove it from the flush-list
+	switch hash {
+	case c.db.oldest:
+		c.db.oldest = node.flushNext
+		c.db.dirties[node.flushNext].flushPrev = types.Hash{}
+	case c.db.newest:
+		c.db.newest = node.flushPrev
+		c.db.dirties[node.flushPrev].flushNext = types.Hash{}
+	default:
+		c.db.dirties[node.flushPrev].flushNext = node.flushNext
+		c.db.dirties[node.flushNext].flushPrev = node.flushPrev
+	}
+
+	// Remove the node from the dirty cache
+	delete(c.db.dirties, hash)
+	c.db.dirtiesSize -= types.StorageSize(types.HashLength + int(node.size))
+
+	if node.children != nil {
+		c.db.childrenSize -= types.StorageSize(cachedNodeChildrenSize + len(node.children)*(types.HashLength+2))
+	}
+
+	// Move the flushed node into the clean cache to prevent insta-reloads
+	if c.db.cleans != nil {
+		c.db.cleans.Set(hash[:], rlp)
+	}
+
+	return nil
+}
+
+// Delete is part of the batch replay interface but is never expected here:
+// commit and Cap batches only contain writes, so any delete is a logic error.
+func (c *cleaner) Delete(key []byte) error {
+	panic("not implemented")
+}
+
+// Update inserts the dirty nodes in provided nodeset into database and
+// link the account trie with multiple storage tries if necessary.
+func (db *Database) Update(nodes *MergedNodeSet) error {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	// Insert dirty nodes into the database. In the same tree, it must be
+	// ensured that children are inserted first, then parent so that children
+	// can be linked with their parent correctly.
+	//
+	// Note, the storage tries must be flushed before the account trie to
+	// retain the invariant that children go into the dirty cache first.
+	// (The zero-hash owner identifies the account trie, everything else is
+	// a storage trie.)
+	order := make([]types.Hash, 0)
+
+	for owner := range nodes.sets {
+		if owner == (types.Hash{}) {
+			continue
+		}
+
+		order = append(order, owner)
+	}
+
+	if _, ok := nodes.sets[types.Hash{}]; ok {
+		order = append(order, types.Hash{})
+	}
+
+	for _, owner := range order {
+		subset := nodes.sets[owner]
+
+		for _, path := range subset.updates.order {
+			n, ok := subset.updates.nodes[path]
+			if !ok {
+				return fmt.Errorf("missing node %x %v", owner, path)
+			}
+
+			db.insert(n.hash, int(n.size), n.node)
+		}
+	}
+
+	// Link up the account trie and storage trie if the node points
+	// to an account trie leaf.
+	if set, present := nodes.sets[types.Hash{}]; present {
+		for _, n := range set.leaves {
+			var account stypes.Account
+			if err := account.UnmarshalRlp(n.blob); err != nil {
+				return err
+			}
+
+			// Non-empty storage roots are referenced by their account leaf.
+			if account.StorageRoot != types.EmptyRootHash {
+				db.reference(account.StorageRoot, n.parent)
+			}
+		}
+	}
+
+	return nil
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer. The second return value is always 0 here (kept
+// for interface compatibility).
+func (db *Database) Size() (types.StorageSize, types.StorageSize) {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	// db.dirtiesSize only contains the useful data in the cache, but when reporting
+	// the total memory consumption, the maintenance metadata is also needed to be
+	// counted.
+	var (
+		metadataSize = types.StorageSize((len(db.dirties) - 1) * cachedNodeSize)
+		metarootRefs = types.StorageSize(len(db.dirties[types.Hash{}].children) * (types.HashLength + 2))
+	)
+
+	return db.dirtiesSize + db.childrenSize + metadataSize - metarootRefs, 0
+}
+
+// GetReader retrieves a node reader belonging to the given state root.
+// The root is ignored: nodes are globally addressed by hash in this scheme.
+func (db *Database) GetReader(root types.Hash) Reader {
+	return newHashReader(db)
+}
+
+// hashReader is reader of hashDatabase which implements the Reader interface.
+type hashReader struct {
+	db *Database
+}
+
+// newHashReader initializes the hash reader.
+func newHashReader(db *Database) *hashReader {
+	return &hashReader{db: db}
+}
+
+// Node retrieves the trie node with the given node hash.
+// No error will be returned if the node is not found.
+// (The owner and path arguments are unused in the hash scheme.)
+func (reader *hashReader) Node(_ types.Hash, _ []byte, hash types.Hash) (node, error) {
+	return reader.db.node(hash), nil
+}
+
+// NodeBlob retrieves the RLP-encoded trie node blob with the given node hash.
+// No error will be returned if the node is not found.
+func (reader *hashReader) NodeBlob(_ types.Hash, _ []byte, hash types.Hash) ([]byte, error) {
+	blob, _ := reader.db.Node(hash)
+
+	return blob, nil
+}
+
+// saveCache saves clean state cache to given directory path
+// using specified CPU cores. It is a no-op when no clean cache is configured.
+func (db *Database) saveCache(dir string, threads int) error {
+	if db.cleans == nil {
+		return nil
+	}
+
+	db.logger.Info("Writing clean trie cache to disk", "path", dir, "threads", threads)
+
+	start := time.Now()
+
+	err := db.cleans.SaveToFileConcurrent(dir, threads)
+	if err != nil {
+		db.logger.Error("Failed to persist clean trie cache", "error", err)
+
+		return err
+	}
+
+	db.logger.Info("Persisted the clean trie cache", "path", dir, "elapsed", types.PrettyDuration(time.Since(start)))
+
+	return nil
+}
+
+// SaveCache atomically saves fast cache data to the given dir using all
+// available CPU cores.
+func (db *Database) SaveCache(dir string) error {
+	return db.saveCache(dir, runtime.GOMAXPROCS(0))
+}
+
+// SaveCachePeriodically atomically saves fast cache data to the given dir with
+// the specified interval. All dump operation will only use a single CPU core.
+// Save errors are logged inside saveCache and otherwise deliberately ignored
+// here (best effort); the loop exits when stopCh closes.
+func (db *Database) SaveCachePeriodically(dir string, interval time.Duration, stopCh <-chan struct{}) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			db.saveCache(dir, 1)
+		case <-stopCh:
+			return
+		}
+	}
+}
+
+// Scheme returns the node scheme used in the database.
+func (db *Database) Scheme() NodeScheme {
+	return &hashScheme{}
+}
diff --git a/trie/database_test.go b/trie/database_test.go
new file mode 100644
index 0000000000..62f56e3462
--- /dev/null
+++ b/trie/database_test.go
@@ -0,0 +1,34 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// Tests that the trie database returns a missing trie node error if attempting
+// to retrieve the meta root (the all-zero hash must never resolve to a node).
+func TestDatabaseMetarootFetch(t *testing.T) {
+	db := NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger())
+	if _, err := db.Node(types.Hash{}); err == nil {
+		t.Fatalf("metaroot retrieval succeeded")
+	}
+}
diff --git a/trie/encoding.go b/trie/encoding.go
new file mode 100644
index 0000000000..242a59b126
--- /dev/null
+++ b/trie/encoding.go
@@ -0,0 +1,168 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+// Trie keys are dealt with in three distinct encodings:
+//
+// KEYBYTES encoding contains the actual key and nothing else. This encoding is the
+// input to most API functions.
+//
+// HEX encoding contains one byte for each nibble of the key and an optional trailing
+// 'terminator' byte of value 0x10 which indicates whether or not the node at the key
+// contains a value. Hex key encoding is used for nodes loaded in memory because it's
+// convenient to access.
+//
+// COMPACT encoding is defined by the Ethereum Yellow Paper (it's called "hex prefix
+// encoding" there) and contains the bytes of the key and a flag. The high nibble of the
+// first byte contains the flag; the lowest bit encoding the oddness of the length and
+// the second-lowest encoding whether the node at the key is a value node. The low nibble
+// of the first byte is zero in the case of an even number of nibbles and the first nibble
+// in the case of an odd number. All remaining nibbles (now an even number) fit properly
+// into the remaining bytes. Compact encoding is used for nodes stored on disk.
+
+func hexToCompact(hex []byte) []byte {
+ terminator := byte(0)
+
+ if hasTerm(hex) {
+ terminator = 1
+ hex = hex[:len(hex)-1]
+ }
+
+ buf := make([]byte, len(hex)/2+1)
+ buf[0] = terminator << 5 // the flag byte
+
+ if len(hex)&1 == 1 {
+ buf[0] |= 1 << 4 // odd flag
+ buf[0] |= hex[0] // first nibble is contained in the first byte
+ hex = hex[1:]
+ }
+
+ decodeNibbles(hex, buf[1:])
+
+ return buf
+}
+
+// hexToCompactInPlace places the compact key in the input buffer, returning the
+// byte length needed for the representation.
+func hexToCompactInPlace(hex []byte) int {
+ var (
+ hexLen = len(hex) // length of the hex input
+ firstByte = byte(0)
+ )
+ // Check if we have a terminator there
+ if hexLen > 0 && hex[hexLen-1] == 16 {
+ firstByte = 1 << 5
+ hexLen-- // last part was the terminator, ignore that
+ }
+
+ var (
+ binLen = hexLen/2 + 1
+ ni = 0 // index in hex
+ bi = 1 // index in bin (compact)
+ )
+
+ if hexLen&1 == 1 {
+ firstByte |= 1 << 4 // odd flag
+ firstByte |= hex[0] // first nibble is contained in the first byte
+ ni++
+ }
+
+ for ; ni < hexLen; bi, ni = bi+1, ni+2 {
+ hex[bi] = hex[ni]<<4 | hex[ni+1]
+ }
+
+ hex[0] = firstByte
+
+ return binLen
+}
+
+func compactToHex(compact []byte) []byte {
+ if len(compact) == 0 {
+ return compact
+ }
+
+ base := keybytesToHex(compact)
+
+ // delete terminator flag
+ if base[0] < 2 {
+ base = base[:len(base)-1]
+ }
+
+ // apply odd flag
+ chop := 2 - base[0]&1
+
+ return base[chop:]
+}
+
+func keybytesToHex(str []byte) []byte {
+ l := len(str)*2 + 1
+
+ var nibbles = make([]byte, l)
+ for i, b := range str {
+ nibbles[i*2] = b / 16
+ nibbles[i*2+1] = b % 16
+ }
+
+ nibbles[l-1] = 16
+
+ return nibbles
+}
+
+// hexToKeybytes turns hex nibbles into key bytes.
+// This can only be used for keys of even length.
+func hexToKeybytes(hex []byte) []byte {
+ if hasTerm(hex) {
+ hex = hex[:len(hex)-1]
+ }
+
+ if len(hex)&1 != 0 {
+ panic("can't convert hex key of odd length")
+ }
+
+ key := make([]byte, len(hex)/2)
+ decodeNibbles(hex, key)
+
+ return key
+}
+
+func decodeNibbles(nibbles []byte, bytes []byte) {
+ for bi, ni := 0, 0; ni < len(nibbles); bi, ni = bi+1, ni+2 {
+ bytes[bi] = nibbles[ni]<<4 | nibbles[ni+1]
+ }
+}
+
+// prefixLen returns the length of the common prefix of a and b.
+func prefixLen(a, b []byte) int {
+ var i, length = 0, len(a)
+
+ if len(b) < length {
+ length = len(b)
+ }
+
+ for ; i < length; i++ {
+ if a[i] != b[i] {
+ break
+ }
+ }
+
+ return i
+}
+
+// hasTerm returns whether a hex key has the terminator flag.
+func hasTerm(s []byte) bool {
+ return len(s) > 0 && s[len(s)-1] == 16
+}
diff --git a/trie/encoding_test.go b/trie/encoding_test.go
new file mode 100644
index 0000000000..16393313f7
--- /dev/null
+++ b/trie/encoding_test.go
@@ -0,0 +1,140 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "encoding/hex"
+ "math/rand"
+ "testing"
+)
+
+func TestHexCompact(t *testing.T) {
+ tests := []struct{ hex, compact []byte }{
+ // empty keys, with and without terminator.
+ {hex: []byte{}, compact: []byte{0x00}},
+ {hex: []byte{16}, compact: []byte{0x20}},
+ // odd length, no terminator
+ {hex: []byte{1, 2, 3, 4, 5}, compact: []byte{0x11, 0x23, 0x45}},
+ // even length, no terminator
+ {hex: []byte{0, 1, 2, 3, 4, 5}, compact: []byte{0x00, 0x01, 0x23, 0x45}},
+ // odd length, terminator
+ {hex: []byte{15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x3f, 0x1c, 0xb8}},
+ // even length, terminator
+ {hex: []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}, compact: []byte{0x20, 0x0f, 0x1c, 0xb8}},
+ }
+ for _, test := range tests {
+ if c := hexToCompact(test.hex); !bytes.Equal(c, test.compact) {
+ t.Errorf("hexToCompact(%x) -> %x, want %x", test.hex, c, test.compact)
+ }
+ if h := compactToHex(test.compact); !bytes.Equal(h, test.hex) {
+ t.Errorf("compactToHex(%x) -> %x, want %x", test.compact, h, test.hex)
+ }
+ }
+}
+
+func TestHexKeybytes(t *testing.T) {
+ tests := []struct{ key, hexIn, hexOut []byte }{
+ {key: []byte{}, hexIn: []byte{16}, hexOut: []byte{16}},
+ {key: []byte{}, hexIn: []byte{}, hexOut: []byte{16}},
+ {
+ key: []byte{0x12, 0x34, 0x56},
+ hexIn: []byte{1, 2, 3, 4, 5, 6, 16},
+ hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+ },
+ {
+ key: []byte{0x12, 0x34, 0x5},
+ hexIn: []byte{1, 2, 3, 4, 0, 5, 16},
+ hexOut: []byte{1, 2, 3, 4, 0, 5, 16},
+ },
+ {
+ key: []byte{0x12, 0x34, 0x56},
+ hexIn: []byte{1, 2, 3, 4, 5, 6},
+ hexOut: []byte{1, 2, 3, 4, 5, 6, 16},
+ },
+ }
+ for _, test := range tests {
+ if h := keybytesToHex(test.key); !bytes.Equal(h, test.hexOut) {
+ t.Errorf("keybytesToHex(%x) -> %x, want %x", test.key, h, test.hexOut)
+ }
+ if k := hexToKeybytes(test.hexIn); !bytes.Equal(k, test.key) {
+ t.Errorf("hexToKeybytes(%x) -> %x, want %x", test.hexIn, k, test.key)
+ }
+ }
+}
+
+func TestHexToCompactInPlace(t *testing.T) {
+ for i, keyS := range []string{
+ "00",
+ "060a040c0f000a090b040803010801010900080d090a0a0d0903000b10",
+ "10",
+ } {
+ hexBytes, _ := hex.DecodeString(keyS)
+ exp := hexToCompact(hexBytes)
+ sz := hexToCompactInPlace(hexBytes)
+ got := hexBytes[:sz]
+ if !bytes.Equal(exp, got) {
+ t.Fatalf("test %d: encoding err\ninp %v\ngot %x\nexp %x\n", i, keyS, got, exp)
+ }
+ }
+}
+
+func TestHexToCompactInPlaceRandom(t *testing.T) {
+ for i := 0; i < 10000; i++ {
+ l := rand.Intn(128)
+ key := make([]byte, l)
+ rand.Read(key)
+ hexBytes := keybytesToHex(key)
+ hexOrig := []byte(string(hexBytes))
+ exp := hexToCompact(hexBytes)
+ sz := hexToCompactInPlace(hexBytes)
+ got := hexBytes[:sz]
+
+ if !bytes.Equal(exp, got) {
+ t.Fatalf("encoding err \ncpt %x\nhex %x\ngot %x\nexp %x\n",
+ key, hexOrig, got, exp)
+ }
+ }
+}
+
+func BenchmarkHexToCompact(b *testing.B) {
+ testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
+ for i := 0; i < b.N; i++ {
+ hexToCompact(testBytes)
+ }
+}
+
+func BenchmarkCompactToHex(b *testing.B) {
+ testBytes := []byte{0, 15, 1, 12, 11, 8, 16 /*term*/}
+ for i := 0; i < b.N; i++ {
+ compactToHex(testBytes)
+ }
+}
+
+func BenchmarkKeybytesToHex(b *testing.B) {
+ testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
+ for i := 0; i < b.N; i++ {
+ keybytesToHex(testBytes)
+ }
+}
+
+func BenchmarkHexToKeybytes(b *testing.B) {
+ testBytes := []byte{7, 6, 6, 5, 7, 2, 6, 2, 16}
+ for i := 0; i < b.N; i++ {
+ hexToKeybytes(testBytes)
+ }
+}
diff --git a/trie/errors.go b/trie/errors.go
new file mode 100644
index 0000000000..75b5ef74a3
--- /dev/null
+++ b/trie/errors.go
@@ -0,0 +1,47 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// MissingNodeError is returned by the trie functions (TryGet, TryUpdate, TryDelete)
+// in the case where a trie node is not present in the local database. It contains
+// information necessary for retrieving the missing node.
+type MissingNodeError struct {
+ Owner types.Hash // owner of the trie if it's 2-layered trie
+ NodeHash types.Hash // hash of the missing node
+ Path []byte // hex-encoded path to the missing node
+ err error // concrete error for missing trie node
+}
+
+// Unwrap returns the concrete error for missing trie node which
+// allows us for further analysis outside.
+func (err *MissingNodeError) Unwrap() error {
+ return err.err
+}
+
+func (err *MissingNodeError) Error() string {
+ if err.Owner == types.ZeroHash {
+ return fmt.Sprintf("missing trie node %x (path %x) %v", err.NodeHash, err.Path, err.err)
+ }
+
+ return fmt.Sprintf("missing trie node %x (owner %x) (path %x) %v", err.NodeHash, err.Owner, err.Path, err.err)
+}
diff --git a/trie/hasher.go b/trie/hasher.go
new file mode 100644
index 0000000000..1a17885b21
--- /dev/null
+++ b/trie/hasher.go
@@ -0,0 +1,232 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "sync"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "golang.org/x/crypto/sha3"
+)
+
+// hasher is a type used for the trie Hash operation. A hasher has some
+// internal preallocated temp space
+type hasher struct {
+ sha crypto.KeccakState
+ tmp []byte
+ encbuf rlp.EncoderBuffer
+ parallel bool // Whether to use parallel threads when hashing
+}
+
+// hasherPool holds pureHashers
+var hasherPool = sync.Pool{
+ New: func() interface{} {
+ return &hasher{
+ tmp: make([]byte, 0, 550), // cap is as large as a full fullNode.
+ //nolint:forcetypeassert
+ sha: sha3.NewLegacyKeccak256().(crypto.KeccakState),
+ encbuf: rlp.NewEncoderBuffer(nil),
+ }
+ },
+}
+
+func newHasher(parallel bool) *hasher {
+ //nolint:forcetypeassert
+ h := hasherPool.Get().(*hasher)
+ h.parallel = parallel
+
+ return h
+}
+
+func returnHasherToPool(h *hasher) {
+ hasherPool.Put(h)
+}
+
+// hash collapses a node down into a hash node, also returning a copy of the
+// original node initialized with the computed hash to replace the original one.
+func (h *hasher) hash(n node, force bool) (hashed node, cached node) {
+ // Return the cached hash if it's available
+ if hash, _ := n.cache(); hash != nil {
+ return hash, n
+ }
+
+ // Trie not processed yet, walk the children
+ switch n := n.(type) {
+ case *shortNode:
+ collapsed, cached := h.hashShortNodeChildren(n)
+ hashed := h.shortnodeToHash(collapsed, force)
+ // We need to retain the possibly _not_ hashed node, in case it was too
+ // small to be hashed
+ if hn, ok := hashed.(hashNode); ok {
+ cached.flags.hash = hn
+ } else {
+ cached.flags.hash = nil
+ }
+
+ return hashed, cached
+ case *fullNode:
+ collapsed, cached := h.hashFullNodeChildren(n)
+ hashed = h.fullnodeToHash(collapsed, force)
+
+ if hn, ok := hashed.(hashNode); ok {
+ cached.flags.hash = hn
+ } else {
+ cached.flags.hash = nil
+ }
+
+ return hashed, cached
+ default:
+ // Value and hash nodes don't have children so they're left as were
+ return n, n
+ }
+}
+
+// hashShortNodeChildren collapses the short node. The returned collapsed node
+// holds a live reference to the Key, and must not be modified.
+// The cached node is left with its original hex-encoded key untouched.
+func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) {
+ // Hash the short node's child, caching the newly hashed subtree
+ collapsed, cached = n.copy(), n.copy()
+ // Previously, we did copy this one. We don't seem to need to actually
+ // do that, since we don't overwrite/reuse keys
+ //cached.Key = common.CopyBytes(n.Key)
+ collapsed.Key = hexToCompact(n.Key)
+
+ // Unless the child is a valuenode or hashnode, hash it
+ switch n.Val.(type) {
+ case *fullNode, *shortNode:
+ collapsed.Val, cached.Val = h.hash(n.Val, false)
+ }
+
+ return collapsed, cached
+}
+
+func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) {
+ // Hash the full node's children, caching the newly hashed subtrees
+ cached = n.copy()
+ collapsed = n.copy()
+
+ if h.parallel {
+ var wg sync.WaitGroup
+
+ wg.Add(16)
+
+ for i := 0; i < 16; i++ {
+ go func(i int) {
+ hasher := newHasher(false)
+ if child := n.Children[i]; child != nil {
+ collapsed.Children[i], cached.Children[i] = hasher.hash(child, false)
+ } else {
+ collapsed.Children[i] = nilValueNode
+ }
+
+ returnHasherToPool(hasher)
+
+ wg.Done()
+ }(i)
+ }
+
+ wg.Wait()
+ } else {
+ for i := 0; i < 16; i++ {
+ if child := n.Children[i]; child != nil {
+ collapsed.Children[i], cached.Children[i] = h.hash(child, false)
+ } else {
+ collapsed.Children[i] = nilValueNode
+ }
+ }
+ }
+
+ return collapsed, cached
+}
+
+// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode
+// should have hex-type Key, which will be converted (without modification)
+// into compact form for RLP encoding.
+// If the rlp data is smaller than 32 bytes, `nil` is returned.
+func (h *hasher) shortnodeToHash(n *shortNode, force bool) node {
+ n.encode(h.encbuf)
+ enc := h.encodedBytes()
+
+ if len(enc) < 32 && !force {
+ return n // Nodes smaller than 32 bytes are stored inside their parent
+ }
+
+ return h.hashData(enc)
+}
+
+// fullnodeToHash is used to create a hashNode from a fullNode whose children
+// are a set of hashNodes (which may contain nil values).
+func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
+ n.encode(h.encbuf)
+ enc := h.encodedBytes()
+
+ if len(enc) < 32 && !force {
+ return n // Nodes smaller than 32 bytes are stored inside their parent
+ }
+
+ return h.hashData(enc)
+}
+
+// encodedBytes returns the result of the last encoding operation on h.encbuf.
+// This also resets the encoder buffer.
+//
+// All node encoding must be done like this:
+//
+// node.encode(h.encbuf)
+// enc := h.encodedBytes()
+//
+// This convention exists because node.encode can only be inlined/escape-analyzed when
+// called on a concrete receiver type.
+func (h *hasher) encodedBytes() []byte {
+ h.tmp = h.encbuf.AppendToBytes(h.tmp[:0])
+ h.encbuf.Reset(nil)
+
+ return h.tmp
+}
+
+// hashData hashes the provided data
+func (h *hasher) hashData(data []byte) hashNode {
+ n := make(hashNode, 32)
+
+ h.sha.Reset()
+ h.sha.Write(data)
+ h.sha.Read(n)
+
+ return n
+}
+
+// proofHash is used to construct trie proofs, and returns the 'collapsed'
+// node (for later RLP encoding) as well as the hashed node -- unless the
+// node is smaller than 32 bytes, in which case it will be returned as is.
+// This method does not do anything on value- or hash-nodes.
+func (h *hasher) proofHash(original node) (collapsed, hashed node) {
+ switch n := original.(type) {
+ case *shortNode:
+ sn, _ := h.hashShortNodeChildren(n)
+
+ return sn, h.shortnodeToHash(sn, false)
+ case *fullNode:
+ fn, _ := h.hashFullNodeChildren(n)
+
+ return fn, h.fullnodeToHash(fn, false)
+ default:
+ // Value and hash nodes don't have children so they're left as were
+ return n, n
+ }
+}
diff --git a/trie/iterator.go b/trie/iterator.go
new file mode 100644
index 0000000000..0722daa768
--- /dev/null
+++ b/trie/iterator.go
@@ -0,0 +1,845 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "container/heap"
+ "errors"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Iterator is a key-value trie iterator that traverses a Trie.
+type Iterator struct {
+ nodeIt NodeIterator
+
+ Key []byte // Current data key on which the iterator is positioned on
+ Value []byte // Current data value on which the iterator is positioned on
+ Err error
+}
+
+// NewIterator creates a new key-value iterator from a node iterator.
+// Note that the value returned by the iterator is raw. If the content is encoded
+// (e.g. storage value is RLP-encoded), it's caller's duty to decode it.
+func NewIterator(it NodeIterator) *Iterator {
+ return &Iterator{
+ nodeIt: it,
+ }
+}
+
+// Next moves the iterator forward one key-value entry.
+func (it *Iterator) Next() bool {
+ for it.nodeIt.Next(true) {
+ if it.nodeIt.Leaf() {
+ it.Key = it.nodeIt.LeafKey()
+ it.Value = it.nodeIt.LeafBlob()
+
+ return true
+ }
+ }
+
+ it.Key = nil
+ it.Value = nil
+ it.Err = it.nodeIt.Error()
+
+ return false
+}
+
+// Prove generates the Merkle proof for the leaf node the iterator is currently
+// positioned on.
+func (it *Iterator) Prove() [][]byte {
+ return it.nodeIt.LeafProof()
+}
+
+// NodeIterator is an iterator to traverse the trie pre-order.
+type NodeIterator interface {
+ // Next moves the iterator to the next node. If the parameter is false, any child
+ // nodes will be skipped.
+ Next(bool) bool
+
+ // Error returns the error status of the iterator.
+ Error() error
+
+ // Hash returns the hash of the current node.
+ Hash() types.Hash
+
+ // Parent returns the hash of the parent of the current node. The hash may be the one
+ // grandparent if the immediate parent is an internal node with no hash.
+ Parent() types.Hash
+
+ // Path returns the hex-encoded path to the current node.
+ // Callers must not retain references to the return value after calling Next.
+ // For leaf nodes, the last element of the path is the 'terminator symbol' 0x10.
+ Path() []byte
+
+ // NodeBlob returns the rlp-encoded value of the current iterated node.
+ // If the node is an embedded node in its parent, nil is returned then.
+ NodeBlob() []byte
+
+ // Leaf returns true iff the current node is a leaf node.
+ Leaf() bool
+
+ // LeafKey returns the key of the leaf. The method panics if the iterator is not
+ // positioned at a leaf. Callers must not retain references to the value after
+ // calling Next.
+ LeafKey() []byte
+
+ // LeafBlob returns the content of the leaf. The method panics if the iterator
+ // is not positioned at a leaf. Callers must not retain references to the value
+ // after calling Next.
+ LeafBlob() []byte
+
+ // LeafProof returns the Merkle proof of the leaf. The method panics if the
+ // iterator is not positioned at a leaf. Callers must not retain references
+ // to the value after calling Next.
+ LeafProof() [][]byte
+
+ // AddResolver sets an intermediate database to use for looking up trie nodes
+ // before reaching into the real persistent layer.
+ //
+ // This is not required for normal operation, rather is an optimization for
+ // cases where trie nodes can be recovered from some external mechanism without
+ // reading from disk. In those cases, this resolver allows short circuiting
+ // accesses and returning them from memory.
+ //
+ // Before adding a similar mechanism to any other place in Geth, consider
+ // making trie.Database an interface and wrapping at that level. It's a huge
+ // refactor, but it could be worth it if another occurrence arises.
+ AddResolver(kvdb.KVReader)
+}
+
+// nodeIteratorState represents the iteration state at one particular node of the
+// trie, which can be resumed at a later invocation.
+type nodeIteratorState struct {
+ hash types.Hash // Hash of the node being iterated (nil if not standalone)
+ node node // Trie node being iterated
+ parent types.Hash // Hash of the first full ancestor node (nil if current is the root)
+ index int // Child to be processed next
+ pathlen int // Length of the path to this node
+}
+
+type nodeIterator struct {
+ trie *Trie // Trie being iterated
+ stack []*nodeIteratorState // Hierarchy of trie nodes persisting the iteration state
+ path []byte // Path to the current node
+ err error // Failure set in case of an internal error in the iterator
+
+ resolver kvdb.KVReader // Optional intermediate resolver above the disk layer
+}
+
+// errIteratorEnd is stored in nodeIterator.err when iteration is done.
+var errIteratorEnd = errors.New("end of iteration")
+
+// seekError is stored in nodeIterator.err if the initial seek has failed.
+type seekError struct {
+ key []byte
+ err error
+}
+
+func (e seekError) Error() string {
+ return "seek error: " + e.err.Error()
+}
+
+func newNodeIterator(trie *Trie, start []byte) NodeIterator {
+ if trie.Hash() == types.EmptyRootHash {
+ return &nodeIterator{
+ trie: trie,
+ err: errIteratorEnd,
+ }
+ }
+
+ it := &nodeIterator{trie: trie}
+ it.err = it.seek(start)
+
+ return it
+}
+
+func (it *nodeIterator) AddResolver(resolver kvdb.KVReader) {
+ it.resolver = resolver
+}
+
+func (it *nodeIterator) Hash() types.Hash {
+ if len(it.stack) == 0 {
+ return types.Hash{}
+ }
+
+ return it.stack[len(it.stack)-1].hash
+}
+
+func (it *nodeIterator) Parent() types.Hash {
+ if len(it.stack) == 0 {
+ return types.Hash{}
+ }
+
+ return it.stack[len(it.stack)-1].parent
+}
+
+func (it *nodeIterator) Leaf() bool {
+ return hasTerm(it.path)
+}
+
+func (it *nodeIterator) LeafKey() []byte {
+ if len(it.stack) > 0 {
+ if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
+ return hexToKeybytes(it.path)
+ }
+ }
+
+ panic("not at leaf")
+}
+
+func (it *nodeIterator) LeafBlob() []byte {
+ if len(it.stack) > 0 {
+ if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
+ return node
+ }
+ }
+
+ panic("not at leaf")
+}
+
+func (it *nodeIterator) LeafProof() [][]byte {
+ if len(it.stack) > 0 {
+ if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
+ hasher := newHasher(false)
+ defer returnHasherToPool(hasher)
+
+ proofs := make([][]byte, 0, len(it.stack))
+
+ for i, item := range it.stack[:len(it.stack)-1] {
+ // Gather nodes that end up as hash nodes (or the root)
+ node, hashed := hasher.proofHash(item.node)
+
+ if _, ok := hashed.(hashNode); ok || i == 0 {
+ proofs = append(proofs, nodeToBytes(node))
+ }
+ }
+
+ return proofs
+ }
+ }
+
+ panic("not at leaf")
+}
+
+func (it *nodeIterator) Path() []byte {
+ return it.path
+}
+
+func (it *nodeIterator) NodeBlob() []byte {
+ if it.Hash() == (types.Hash{}) {
+ return nil // skip the non-standalone node
+ }
+
+ blob, err := it.resolveBlob(it.Hash().Bytes(), it.Path())
+ if err != nil {
+ it.err = err
+
+ return nil
+ }
+
+ return blob
+}
+
+func (it *nodeIterator) Error() error {
+ if errors.Is(it.err, errIteratorEnd) {
+ return nil
+ }
+
+ var seek seekError
+ if errors.As(it.err, &seek) {
+ return seek.err
+ }
+
+ return it.err
+}
+
+// Next moves the iterator to the next node, returning whether there are any
+// further nodes. In case of an internal error this method returns false and
+// sets the Error field to the encountered failure. If `descend` is false,
+// skips iterating over any subnodes of the current node.
+func (it *nodeIterator) Next(descend bool) bool {
+ if errors.Is(it.err, errIteratorEnd) {
+ return false
+ }
+
+ var seek seekError
+ if errors.As(it.err, &seek) {
+ if it.err = it.seek(seek.key); it.err != nil {
+ return false
+ }
+ }
+
+ // Otherwise step forward with the iterator and report any errors.
+ state, parentIndex, path, err := it.peek(descend)
+ it.err = err
+
+ if it.err != nil {
+ return false
+ }
+
+ it.push(state, parentIndex, path)
+
+ return true
+}
+
+func (it *nodeIterator) seek(prefix []byte) error {
+ // The path we're looking for is the hex encoded key without terminator.
+ key := keybytesToHex(prefix)
+ key = key[:len(key)-1]
+ // Move forward until we're just before the closest match to key.
+ for {
+ state, parentIndex, path, err := it.peekSeek(key)
+ if errors.Is(err, errIteratorEnd) {
+ return errIteratorEnd
+ } else if err != nil {
+ return seekError{prefix, err}
+ } else if bytes.Compare(path, key) >= 0 {
+ return nil
+ }
+
+ it.push(state, parentIndex, path)
+ }
+}
+
+// init initializes the iterator.
+func (it *nodeIterator) init() (*nodeIteratorState, error) {
+ root := it.trie.Hash()
+ state := &nodeIteratorState{node: it.trie.root, index: -1}
+
+ if root != types.EmptyRootHash {
+ state.hash = root
+ }
+
+ return state, state.resolve(it, nil)
+}
+
+// peek creates the next state of the iterator.
+func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
+ // Initialize the iterator if we've just started.
+ if len(it.stack) == 0 {
+ state, err := it.init()
+
+ return state, nil, nil, err
+ }
+
+ if !descend {
+ // If we're skipping children, pop the current node first
+ it.pop()
+ }
+
+ // Continue iteration to the next child
+ for len(it.stack) > 0 {
+ parent := it.stack[len(it.stack)-1]
+ ancestor := parent.hash
+
+ if ancestor == types.ZeroHash {
+ ancestor = parent.parent
+ }
+
+ state, path, ok := it.nextChild(parent, ancestor)
+ if ok {
+ if err := state.resolve(it, path); err != nil {
+ return parent, &parent.index, path, err
+ }
+
+ return state, &parent.index, path, nil
+ }
+ // No more child nodes, move back up.
+ it.pop()
+ }
+
+ return nil, nil, nil, errIteratorEnd
+}
+
+// peekSeek is like peek, but it also tries to skip resolving hashes by skipping
+// over the siblings that do not lead towards the desired seek position.
+func (it *nodeIterator) peekSeek(seekKey []byte) (*nodeIteratorState, *int, []byte, error) {
+ // Initialize the iterator if we've just started.
+ if len(it.stack) == 0 {
+ state, err := it.init()
+
+ return state, nil, nil, err
+ }
+
+ if !bytes.HasPrefix(seekKey, it.path) {
+ // If we're skipping children, pop the current node first
+ it.pop()
+ }
+
+ // Continue iteration to the next child
+ for len(it.stack) > 0 {
+ parent := it.stack[len(it.stack)-1]
+
+ ancestor := parent.hash
+ if ancestor == types.ZeroHash {
+ ancestor = parent.parent
+ }
+
+ state, path, ok := it.nextChildAt(parent, ancestor, seekKey)
+ if ok {
+ if err := state.resolve(it, path); err != nil {
+ return parent, &parent.index, path, err
+ }
+
+ return state, &parent.index, path, nil
+ }
+ // No more child nodes, move back up.
+ it.pop()
+ }
+
+ return nil, nil, nil, errIteratorEnd
+}
+
+func (it *nodeIterator) resolveHash(hash hashNode, path []byte) (node, error) {
+ if it.resolver != nil {
+ if blob, _, err := it.resolver.Get(hash); err == nil && len(blob) > 0 {
+ if resolved, err := decodeNode(hash, blob); err == nil {
+ return resolved, nil
+ }
+ }
+ }
+ // Retrieve the specified node from the underlying node reader.
+ // it.trie.resolveAndTrack is not used since in that function the
+ // loaded blob will be tracked, while it's not required here since
+ // all loaded nodes won't be linked to trie at all and track nodes
+ // may lead to out-of-memory issue.
+ return it.trie.reader.node(path, types.BytesToHash(hash))
+}
+
+func (it *nodeIterator) resolveBlob(hash hashNode, path []byte) ([]byte, error) {
+ if it.resolver != nil {
+ if blob, _, err := it.resolver.Get(hash); err == nil && len(blob) > 0 {
+ return blob, nil
+ }
+ }
+ // Retrieve the specified node from the underlying node reader.
+ // it.trie.resolveAndTrack is not used since in that function the
+ // loaded blob will be tracked, while it's not required here since
+ // all loaded nodes won't be linked to trie at all and track nodes
+ // may lead to out-of-memory issue.
+ return it.trie.reader.nodeBlob(path, types.BytesToHash(hash))
+}
+
+func (st *nodeIteratorState) resolve(it *nodeIterator, path []byte) error {
+ if hash, ok := st.node.(hashNode); ok {
+ resolved, err := it.resolveHash(hash, path)
+ if err != nil {
+ return err
+ }
+
+ st.node = resolved
+ st.hash = types.BytesToHash(hash)
+ }
+
+ return nil
+}
+
+func findChild(n *fullNode, index int, path []byte, ancestor types.Hash) (node, *nodeIteratorState, []byte, int) {
+ var (
+ child node
+ state *nodeIteratorState
+ childPath []byte
+ )
+
+ for ; index < len(n.Children); index++ {
+ if n.Children[index] != nil {
+ child = n.Children[index]
+ hash, _ := child.cache()
+ state = &nodeIteratorState{
+ hash: types.BytesToHash(hash),
+ node: child,
+ parent: ancestor,
+ index: -1,
+ pathlen: len(path),
+ }
+
+ childPath = append(childPath, path...)
+ childPath = append(childPath, byte(index))
+
+ return child, state, childPath, index
+ }
+ }
+
+ return nil, nil, nil, 0
+}
+
+// nextChild returns the iterator state, path and success flag for the next
+// child of the parent node. fullNode children are visited left-to-right via
+// findChild; a shortNode has a single child that is visited exactly once
+// (guarded by parent.index < 0). When no child remains, the parent state and
+// current path are returned with false.
+func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor types.Hash) (*nodeIteratorState, []byte, bool) {
+	switch node := parent.node.(type) {
+	case *fullNode:
+		// Full node, move to the first non-nil child.
+		if child, state, path, index := findChild(node, parent.index+1, it.path, ancestor); child != nil {
+			// index-1 because push() increments the parent index afterwards.
+			parent.index = index - 1
+
+			return state, path, true
+		}
+	case *shortNode:
+		// Short node, return the pointer singleton child
+		if parent.index < 0 {
+			hash, _ := node.Val.cache()
+			state := &nodeIteratorState{
+				hash:    types.BytesToHash(hash),
+				node:    node.Val,
+				parent:  ancestor,
+				index:   -1,
+				pathlen: len(it.path),
+			}
+
+			// NOTE(review): append may reuse it.path's backing array; this
+			// assumes push() immediately replaces it.path with the result —
+			// confirm no stale alias is retained elsewhere.
+			path := append(it.path, node.Key...)
+
+			return state, path, true
+		}
+	}
+
+	return parent, it.path, false
+}
+
+// nextChildAt is similar to nextChild, except that it targets a child as close to the
+// target key as possible, thus skipping siblings. This is used when seeking:
+// instead of descending into every sibling before the target, it walks the
+// fullNode children until the last one whose path is still below the key.
+func (it *nodeIterator) nextChildAt(
+	parent *nodeIteratorState,
+	ancestor types.Hash,
+	key []byte,
+) (*nodeIteratorState, []byte, bool) {
+	switch n := parent.node.(type) {
+	case *fullNode:
+		// Full node, move to the first non-nil child before the desired key position
+		child, state, path, index := findChild(n, parent.index+1, it.path, ancestor)
+		if child == nil {
+			// No more children in this fullnode
+			return parent, it.path, false
+		}
+		// If the child we found is already past the seek position, just return it.
+		if bytes.Compare(path, key) >= 0 {
+			parent.index = index - 1
+
+			return state, path, true
+		}
+		// The child is before the seek position. Try advancing
+		for {
+			nextChild, nextState, nextPath, nextIndex := findChild(n, index+1, it.path, ancestor)
+			// If we run out of children, or skipped past the target, return the
+			// previous one
+			if nextChild == nil || bytes.Compare(nextPath, key) >= 0 {
+				parent.index = index - 1
+
+				return state, path, true
+			}
+			// We found a better child closer to the target
+			state, path, index = nextState, nextPath, nextIndex
+		}
+	case *shortNode:
+		// Short node, return the pointer singleton child
+		if parent.index < 0 {
+			hash, _ := n.Val.cache()
+			state := &nodeIteratorState{
+				hash:    types.BytesToHash(hash),
+				node:    n.Val,
+				parent:  ancestor,
+				index:   -1,
+				pathlen: len(it.path),
+			}
+
+			// NOTE(review): same aliasing caveat as nextChild — append may
+			// reuse it.path's backing array; confirm push() consumes it.
+			path := append(it.path, n.Key...)
+
+			return state, path, true
+		}
+	}
+
+	return parent, it.path, false
+}
+
+// push descends into a child: it adopts the child's path, pushes the child
+// state onto the stack, and bumps the parent's child index (when given) so
+// the next nextChild call continues with the following sibling.
+func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
+	it.path = path
+	it.stack = append(it.stack, state)
+
+	if parentIndex != nil {
+		*parentIndex++
+	}
+}
+
+// pop ascends one level: it truncates the path back to the parent's length
+// and removes the topmost stack entry. The slot is nilled out first so the
+// retired state can be garbage collected.
+func (it *nodeIterator) pop() {
+	last := it.stack[len(it.stack)-1]
+	it.path = it.path[:last.pathlen]
+	it.stack[len(it.stack)-1] = nil
+	it.stack = it.stack[:len(it.stack)-1]
+}
+
+// compareNodes orders two iterator positions: first lexically by path, then
+// leaves before non-leaves at the same path, then by node hash, and finally
+// by leaf contents. Returns -1, 0 or 1 like bytes.Compare.
+func compareNodes(a, b NodeIterator) int {
+	if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
+		return cmp
+	}
+
+	if a.Leaf() && !b.Leaf() {
+		return -1
+	} else if b.Leaf() && !a.Leaf() {
+		return 1
+	}
+
+	if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
+		return cmp
+	}
+
+	if a.Leaf() && b.Leaf() {
+		return bytes.Compare(a.LeafBlob(), b.LeafBlob())
+	}
+
+	return 0
+}
+
+// differenceIterator yields the nodes present in trie b but not in trie a
+// (set difference b - a), exploiting node hashes to skip identical subtrees.
+type differenceIterator struct {
+	a, b  NodeIterator // Nodes returned are those in b - a.
+	eof   bool         // Indicates a has run out of elements
+	count int          // Number of nodes scanned on either trie
+}
+
+// NewDifferenceIterator constructs a NodeIterator that iterates over elements in b that
+// are not in a. Returns the iterator, and a pointer to an integer recording the number
+// of nodes seen.
+func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
+	// Position a on its first element; b is advanced lazily inside Next.
+	a.Next(true)
+
+	it := &differenceIterator{
+		a: a,
+		b: b,
+	}
+
+	return it, &it.count
+}
+
+// The accessors below all delegate to iterator b, since every element a
+// difference iterator produces comes from b by construction.
+func (it *differenceIterator) Hash() types.Hash {
+	return it.b.Hash()
+}
+
+func (it *differenceIterator) Parent() types.Hash {
+	return it.b.Parent()
+}
+
+func (it *differenceIterator) Leaf() bool {
+	return it.b.Leaf()
+}
+
+func (it *differenceIterator) LeafKey() []byte {
+	return it.b.LeafKey()
+}
+
+func (it *differenceIterator) LeafBlob() []byte {
+	return it.b.LeafBlob()
+}
+
+func (it *differenceIterator) LeafProof() [][]byte {
+	return it.b.LeafProof()
+}
+
+func (it *differenceIterator) Path() []byte {
+	return it.b.Path()
+}
+
+func (it *differenceIterator) NodeBlob() []byte {
+	return it.b.NodeBlob()
+}
+
+// AddResolver is not supported on difference iterators.
+func (it *differenceIterator) AddResolver(resolver kvdb.KVReader) {
+	panic("not implemented")
+}
+
+// Next advances to the next node of b that is absent from a. The descend
+// argument is ignored: b is always fully descended, while identical subtrees
+// (matched by hash) are skipped in both iterators.
+func (it *differenceIterator) Next(bool) bool {
+	// Invariants:
+	// - We always advance at least one element in b.
+	// - At the start of this function, a's path is lexically greater than b's.
+	if !it.b.Next(true) {
+		return false
+	}
+
+	it.count++
+
+	if it.eof {
+		// a has reached eof, so we just return all elements from b
+		return true
+	}
+
+	for {
+		switch compareNodes(it.a, it.b) {
+		case -1:
+			// b jumped past a; advance a
+			if !it.a.Next(true) {
+				it.eof = true
+
+				return true
+			}
+
+			it.count++
+		case 1:
+			// b is before a
+			return true
+		case 0:
+			// a and b are identical; skip this whole subtree if the nodes have hashes
+			// NOTE: despite its name, hasHash is true when the node has NO
+			// hash (embedded node) — then we must still descend, because the
+			// subtree cannot be proven identical by hash comparison alone.
+			hasHash := it.a.Hash() == types.ZeroHash
+
+			if !it.b.Next(hasHash) {
+				return false
+			}
+
+			it.count++
+
+			if !it.a.Next(hasHash) {
+				it.eof = true
+
+				return true
+			}
+
+			it.count++
+		}
+	}
+}
+
+// Error reports the first failure seen by either underlying iterator,
+// preferring a's error over b's.
+func (it *differenceIterator) Error() error {
+	if err := it.a.Error(); err != nil {
+		return err
+	}
+
+	return it.b.Error()
+}
+
+// nodeIteratorHeap implements heap.Interface over a set of NodeIterators,
+// ordered by compareNodes on their current positions, so the iterator with
+// the lexically smallest next element sits at the root.
+type nodeIteratorHeap []NodeIterator
+
+func (h nodeIteratorHeap) Len() int            { return len(h) }
+func (h nodeIteratorHeap) Less(i, j int) bool  { return compareNodes(h[i], h[j]) < 0 }
+func (h nodeIteratorHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
+func (h *nodeIteratorHeap) Push(x interface{}) {
+	//nolint:forcetypeassert
+	*h = append(*h, x.(NodeIterator))
+}
+func (h *nodeIteratorHeap) Pop() interface{} {
+	n := len(*h)
+	x := (*h)[n-1]
+	*h = (*h)[0 : n-1]
+
+	return x
+}
+
+// unionIterator yields each node present in any of a set of tries exactly
+// once, merging the per-trie iterators through a min-heap.
+type unionIterator struct {
+	items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
+	count int               // Number of nodes scanned across all tries
+}
+
+// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
+// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
+// recording the number of nodes visited.
+func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
+	// Copy the input so heap reordering does not mutate the caller's slice.
+	h := make(nodeIteratorHeap, len(iters))
+	copy(h, iters)
+	heap.Init(&h)
+
+	ui := &unionIterator{items: &h}
+
+	return ui, &ui.count
+}
+
+// The accessors below all delegate to the heap minimum — the iterator whose
+// current element is the union's current element.
+func (it *unionIterator) Hash() types.Hash {
+	return (*it.items)[0].Hash()
+}
+
+func (it *unionIterator) Parent() types.Hash {
+	return (*it.items)[0].Parent()
+}
+
+func (it *unionIterator) Leaf() bool {
+	return (*it.items)[0].Leaf()
+}
+
+func (it *unionIterator) LeafKey() []byte {
+	return (*it.items)[0].LeafKey()
+}
+
+func (it *unionIterator) LeafBlob() []byte {
+	return (*it.items)[0].LeafBlob()
+}
+
+func (it *unionIterator) LeafProof() [][]byte {
+	return (*it.items)[0].LeafProof()
+}
+
+func (it *unionIterator) Path() []byte {
+	return (*it.items)[0].Path()
+}
+
+func (it *unionIterator) NodeBlob() []byte {
+	return (*it.items)[0].NodeBlob()
+}
+
+// AddResolver is not supported on union iterators.
+func (it *unionIterator) AddResolver(resolver kvdb.KVReader) {
+	panic("not implemented")
+}
+
+// Next returns the next node in the union of tries being iterated over.
+//
+// It does this by maintaining a heap of iterators, sorted by the iteration
+// order of their next elements, with one entry for each source trie. Each
+// time Next() is called, it takes the least element from the heap to return,
+// advancing any other iterators that also point to that same element. These
+// iterators are called with descend=false, since we know that any nodes under
+// these nodes will also be duplicates, found in the currently selected iterator.
+// Whenever an iterator is advanced, it is pushed back into the heap if it still
+// has elements remaining.
+//
+// In the case that descend=false - eg, we're asked to ignore all subnodes of the
+// current node - we also advance any iterators in the heap that have the current
+// path as a prefix.
+func (it *unionIterator) Next(descend bool) bool {
+	if len(*it.items) == 0 {
+		return false
+	}
+
+	// Get the next key from the union
+	//nolint:forcetypeassert
+	least := heap.Pop(it.items).(NodeIterator)
+
+	// Skip over other nodes as long as they're identical, or, if we're not descending, as
+	// long as they have the same prefix as the current node.
+	for len(*it.items) > 0 &&
+		((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) ||
+			compareNodes(least, (*it.items)[0]) == 0) {
+		//nolint:forcetypeassert
+		skipped := heap.Pop(it.items).(NodeIterator)
+		// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
+		// (descend only when the hash is zero, i.e. an embedded node).
+		if skipped.Next(skipped.Hash() == types.ZeroHash) {
+			it.count++
+			// If there are more elements, push the iterator back on the heap
+			heap.Push(it.items, skipped)
+		}
+	}
+
+	// Advance the chosen iterator; it is re-inserted only if it still has
+	// elements, so exhausted tries drop out of the heap naturally.
+	if least.Next(descend) {
+		it.count++
+		heap.Push(it.items, least)
+	}
+
+	return len(*it.items) > 0
+}
+
+// Error returns the first failure reported by any of the merged iterators,
+// or nil if none have failed.
+func (it *unionIterator) Error() error {
+	for i := 0; i < len(*it.items); i++ {
+		if err := (*it.items)[i].Error(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
new file mode 100644
index 0000000000..472dc9a2fe
--- /dev/null
+++ b/trie/iterator_test.go
@@ -0,0 +1,616 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// TestEmptyIterator checks that iterating an empty trie visits no nodes.
+func TestEmptyIterator(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	iter := trie.NodeIterator(nil)
+
+	seen := make(map[string]struct{})
+	for iter.Next(true) {
+		seen[string(iter.Path())] = struct{}{}
+	}
+	if len(seen) != 0 {
+		t.Fatal("Unexpected trie node iterated")
+	}
+}
+
+// TestIterator commits a small key/value set, reopens the trie from the
+// database, and verifies the leaf iterator returns every pair intact.
+func TestIterator(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	db := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	trie := NewEmpty(db)
+	vals := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"dog", "puppy"},
+		{"somethingveryoddindeedthis is", "myothernodedata"},
+	}
+	all := make(map[string]string)
+	for _, val := range vals {
+		all[val.k] = val.v
+		trie.Update([]byte(val.k), []byte(val.v))
+	}
+	root, nodes, err := trie.Commit(false)
+	if err != nil {
+		t.Fatalf("Failed to commit trie %v", err)
+	}
+	db.Update(NewWithNodeSet(nodes))
+
+	// Reopen from the committed root so iteration exercises DB reads.
+	trie, _ = New(TrieID(root), db, logger)
+	found := make(map[string]string)
+	it := NewIterator(trie.NodeIterator(nil))
+	for it.Next() {
+		found[string(it.Key)] = string(it.Value)
+	}
+
+	for k, v := range all {
+		if found[k] != v {
+			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
+		}
+	}
+}
+
+// kv is a test fixture pairing a key/value with a "touched" flag used to
+// track which entries an iterator has visited.
+type kv struct {
+	k, v []byte
+	t    bool
+}
+
+// TestIteratorLargeData inserts ~510 padded keys and checks the iterator
+// touches every one of them, reporting any that were missed.
+func TestIteratorLargeData(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	vals := make(map[string]*kv)
+
+	for i := byte(0); i < 255; i++ {
+		value := &kv{types.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+		value2 := &kv{types.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
+		trie.Update(value.k, value.v)
+		trie.Update(value2.k, value2.v)
+		vals[string(value.k)] = value
+		vals[string(value2.k)] = value2
+	}
+
+	it := NewIterator(trie.NodeIterator(nil))
+	for it.Next() {
+		vals[string(it.Key)].t = true
+	}
+
+	var untouched []*kv
+	for _, value := range vals {
+		if !value.t {
+			untouched = append(untouched, value)
+		}
+	}
+
+	if len(untouched) > 0 {
+		t.Errorf("Missed %d nodes", len(untouched))
+		for _, value := range untouched {
+			t.Error(value)
+		}
+	}
+}
+
+// Tests that the node iterator indeed walks over the entire database contents.
+// It cross-checks in both directions: every iterated hash must resolve from
+// the database, and every dirty/persisted entry must have been iterated.
+func TestNodeIteratorCoverage(t *testing.T) {
+	// Create some arbitrary test trie to iterate
+	db, trie, _ := makeTestTrie()
+
+	// Gather all the node hashes found by the iterator
+	hashes := make(map[types.Hash]struct{})
+	for it := trie.NodeIterator(nil); it.Next(true); {
+		if it.Hash() != (types.Hash{}) {
+			hashes[it.Hash()] = struct{}{}
+		}
+	}
+	// Cross check the hashes and the database itself
+	for hash := range hashes {
+		if _, err := db.Node(hash); err != nil {
+			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
+		}
+	}
+	for hash, obj := range db.dirties {
+		if obj != nil && hash != (types.Hash{}) {
+			if _, ok := hashes[hash]; !ok {
+				t.Errorf("state entry not reported %x", hash)
+			}
+		}
+	}
+	it := db.diskdb.NewIterator(nil, nil)
+	for it.Next() {
+		key := it.Key()
+		if _, ok := hashes[types.BytesToHash(key)]; !ok {
+			t.Errorf("state entry not reported %x", key)
+		}
+	}
+	it.Release()
+}
+
+// kvs is a simple string key/value pair used by the seek/difference/union
+// iterator tests below.
+type kvs struct{ k, v string }
+
+// testdata1 and testdata2 are overlapping key sets; their differences drive
+// TestDifferenceIterator and TestUnionIterator.
+var testdata1 = []kvs{
+	{"barb", "ba"},
+	{"bard", "bc"},
+	{"bars", "bb"},
+	{"bar", "b"},
+	{"fab", "z"},
+	{"food", "ab"},
+	{"foos", "aa"},
+	{"foo", "a"},
+}
+
+var testdata2 = []kvs{
+	{"aardvark", "c"},
+	{"bar", "b"},
+	{"barb", "bd"},
+	{"bars", "be"},
+	{"fab", "z"},
+	{"foo", "a"},
+	{"foos", "aa"},
+	{"food", "ab"},
+	{"jars", "d"},
+}
+
+// TestIteratorSeek verifies seeking to an existing key, a non-existent key
+// (lands on the next greater key), and past the end (yields nothing).
+func TestIteratorSeek(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	for _, val := range testdata1 {
+		trie.Update([]byte(val.k), []byte(val.v))
+	}
+
+	// Seek to the middle.
+	it := NewIterator(trie.NodeIterator([]byte("fab")))
+	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
+		t.Fatal(err)
+	}
+
+	// Seek to a non-existent key.
+	it = NewIterator(trie.NodeIterator([]byte("barc")))
+	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
+		t.Fatal(err)
+	}
+
+	// Seek beyond the end.
+	it = NewIterator(trie.NodeIterator([]byte("z")))
+	if err := checkIteratorOrder(nil, it); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// checkIteratorOrder drains the iterator and verifies the keys appear in
+// exactly the order given by want, with no extras and no early termination.
+// NOTE: want entries must already be sorted in the trie's iteration order.
+func checkIteratorOrder(want []kvs, it *Iterator) error {
+	for it.Next() {
+		if len(want) == 0 {
+			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
+		}
+		if !bytes.Equal(it.Key, []byte(want[0].k)) {
+			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
+		}
+		want = want[1:]
+	}
+	if len(want) > 0 {
+		return fmt.Errorf("iterator ended early, want key %q", want[0])
+	}
+	return nil
+}
+
+// TestDifferenceIterator builds tries from testdata1 and testdata2 and
+// checks the difference iterator yields exactly the leaves of testdata2
+// that are absent from or changed relative to testdata1.
+func TestDifferenceIterator(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	dba := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	triea := NewEmpty(dba)
+	for _, val := range testdata1 {
+		triea.Update([]byte(val.k), []byte(val.v))
+	}
+	rootA, nodesA, _ := triea.Commit(false)
+	dba.Update(NewWithNodeSet(nodesA))
+	triea, _ = New(TrieID(rootA), dba, logger)
+
+	dbb := NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger())
+	trieb := NewEmpty(dbb)
+	for _, val := range testdata2 {
+		trieb.Update([]byte(val.k), []byte(val.v))
+	}
+	rootB, nodesB, _ := trieb.Commit(false)
+	dbb.Update(NewWithNodeSet(nodesB))
+	trieb, _ = New(TrieID(rootB), dbb, logger)
+
+	found := make(map[string]string)
+	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
+	it := NewIterator(di)
+	for it.Next() {
+		found[string(it.Key)] = string(it.Value)
+	}
+
+	// Expected set: entries of testdata2 that are new or have changed values.
+	all := []struct{ k, v string }{
+		{"aardvark", "c"},
+		{"barb", "bd"},
+		{"bars", "be"},
+		{"jars", "d"},
+	}
+	for _, item := range all {
+		if found[item.k] != item.v {
+			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
+		}
+	}
+	if len(found) != len(all) {
+		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
+	}
+}
+
+// TestUnionIterator merges tries built from testdata1 and testdata2 and
+// checks the union iterator emits every distinct leaf exactly once, in
+// trie iteration order (keys with different values appear once per value).
+func TestUnionIterator(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	dba := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	triea := NewEmpty(dba)
+	for _, val := range testdata1 {
+		triea.Update([]byte(val.k), []byte(val.v))
+	}
+	rootA, nodesA, _ := triea.Commit(false)
+	dba.Update(NewWithNodeSet(nodesA))
+	triea, _ = New(TrieID(rootA), dba, logger)
+
+	dbb := NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger())
+	trieb := NewEmpty(dbb)
+	for _, val := range testdata2 {
+		trieb.Update([]byte(val.k), []byte(val.v))
+	}
+	rootB, nodesB, _ := trieb.Commit(false)
+	dbb.Update(NewWithNodeSet(nodesB))
+	trieb, _ = New(TrieID(rootB), dbb, logger)
+
+	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
+	it := NewIterator(di)
+
+	all := []struct{ k, v string }{
+		{"aardvark", "c"},
+		{"barb", "ba"},
+		{"barb", "bd"},
+		{"bard", "bc"},
+		{"bars", "bb"},
+		{"bars", "be"},
+		{"bar", "b"},
+		{"fab", "z"},
+		{"food", "ab"},
+		{"foos", "aa"},
+		{"foo", "a"},
+		{"jars", "d"},
+	}
+
+	for i, kv := range all {
+		if !it.Next() {
+			t.Errorf("Iterator ends prematurely at element %d", i)
+		}
+		if kv.k != string(it.Key) {
+			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
+		}
+		if kv.v != string(it.Value) {
+			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
+		}
+	}
+	if it.Next() {
+		t.Errorf("Iterator returned extra values.")
+	}
+}
+
+// TestIteratorNoDups ensures a plain iteration never visits the same node
+// path twice.
+func TestIteratorNoDups(t *testing.T) {
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	for _, val := range testdata1 {
+		tr.Update([]byte(val.k), []byte(val.v))
+	}
+	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+}
+
+// This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
+// It repeatedly deletes a random non-root node, iterates until the resulting
+// MissingNodeError, restores the node, and verifies iteration resumes and
+// completes with the expected total node count.
+func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
+
+func testIteratorContinueAfterError(t *testing.T, memonly bool) {
+	t.Helper()
+
+	diskdb := rawdb.NewMemoryDatabase()
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(diskdb, logger)
+
+	tr := NewEmpty(triedb)
+	for _, val := range testdata1 {
+		tr.Update([]byte(val.k), []byte(val.v))
+	}
+	_, nodes, _ := tr.Commit(false)
+	triedb.Update(NewWithNodeSet(nodes))
+	if !memonly {
+		triedb.Commit(tr.Hash(), true, nil)
+	}
+	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
+
+	var (
+		diskKeys [][]byte
+		memKeys  []types.Hash
+	)
+	if memonly {
+		memKeys = triedb.Nodes()
+	} else {
+		it := diskdb.NewIterator(nil, nil)
+		for it.Next() {
+			diskKeys = append(diskKeys, it.Key())
+		}
+		it.Release()
+	}
+	for i := 0; i < 20; i++ {
+		// Create trie that will load all nodes from DB.
+		tr, _ := New(TrieID(tr.Hash()), triedb, logger)
+
+		// Remove a random node from the database. It can't be the root node
+		// because that one is already loaded.
+		var (
+			rkey types.Hash
+			rval []byte
+			robj *cachedNode
+		)
+		for {
+			if memonly {
+				rkey = memKeys[rand.Intn(len(memKeys))]
+			} else {
+				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
+			}
+			if rkey != tr.Hash() {
+				break
+			}
+		}
+		if memonly {
+			robj = triedb.dirties[rkey]
+			delete(triedb.dirties, rkey)
+		} else {
+			rval, _, _ = diskdb.Get(rkey[:])
+			diskdb.Delete(rkey[:])
+		}
+		// Iterate until the error is hit.
+		seen := make(map[string]bool)
+		it := tr.NodeIterator(nil)
+		checkIteratorNoDups(t, it, seen)
+		missing, ok := it.Error().(*MissingNodeError)
+		if !ok || missing.NodeHash != rkey {
+			t.Fatal("didn't hit missing node, got", it.Error())
+		}
+
+		// Add the node back and continue iteration.
+		if memonly {
+			triedb.dirties[rkey] = robj
+		} else {
+			diskdb.Set(rkey[:], rval)
+		}
+		checkIteratorNoDups(t, it, seen)
+		if it.Error() != nil {
+			t.Fatal("unexpected error", it.Error())
+		}
+		if len(seen) != wantNodeCount {
+			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
+		}
+	}
+}
+
+// Similar to the test above, this one checks that failure to create nodeIterator at a
+// certain key prefix behaves correctly when Next is called. The expectation is that Next
+// should retry seeking before returning true for the first time.
+func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
+	testIteratorContinueAfterSeekError(t, false)
+}
+func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
+	testIteratorContinueAfterSeekError(t, true)
+}
+
+func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
+	t.Helper()
+
+	// Commit test trie to db, then remove the node containing "bars".
+	diskdb := rawdb.NewMemoryDatabase()
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(diskdb, logger)
+
+	ctr := NewEmpty(triedb)
+	for _, val := range testdata1 {
+		ctr.Update([]byte(val.k), []byte(val.v))
+	}
+	root, nodes, _ := ctr.Commit(false)
+	triedb.Update(NewWithNodeSet(nodes))
+	if !memonly {
+		triedb.Commit(root, true, nil)
+	}
+	// Precomputed hash of the node holding "bars" for testdata1 — must be
+	// regenerated if testdata1 or the node encoding changes.
+	barNodeHash := types.StringToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
+	var (
+		barNodeBlob []byte
+		barNodeObj  *cachedNode
+	)
+	if memonly {
+		barNodeObj = triedb.dirties[barNodeHash]
+		delete(triedb.dirties, barNodeHash)
+	} else {
+		barNodeBlob, _, _ = diskdb.Get(barNodeHash[:])
+		diskdb.Delete(barNodeHash[:])
+	}
+	// Create a new iterator that seeks to "bars". Seeking can't proceed because
+	// the node is missing.
+	tr, _ := New(TrieID(root), triedb, logger)
+	it := tr.NodeIterator([]byte("bars"))
+	missing, ok := it.Error().(*MissingNodeError)
+	if !ok {
+		t.Fatal("want MissingNodeError, got", it.Error())
+	} else if missing.NodeHash != barNodeHash {
+		t.Fatal("wrong node missing")
+	}
+	// Reinsert the missing node.
+	if memonly {
+		triedb.dirties[barNodeHash] = barNodeObj
+	} else {
+		diskdb.Set(barNodeHash[:], barNodeBlob)
+	}
+	// Check that iteration produces the right set of values.
+	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// checkIteratorNoDups drains the iterator, failing the test if any node path
+// repeats, and returns the total number of distinct paths seen. A non-nil
+// seen map lets callers resume the count across a retried iteration.
+func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
+	t.Helper()
+
+	if seen == nil {
+		seen = make(map[string]bool)
+	}
+	for it.Next(true) {
+		if seen[string(it.Path())] {
+			t.Fatalf("iterator visited node path %x twice", it.Path())
+		}
+		seen[string(it.Path())] = true
+	}
+	return len(seen)
+}
+
+// loggingDB wraps a kvdb.KVBatchStorage and counts Get calls, so tests can
+// assert how many database lookups an operation performs.
+// NOTE: getCount is incremented without synchronization — single-goroutine
+// test use only.
+type loggingDB struct {
+	getCount uint64
+	backend  kvdb.KVBatchStorage
+}
+
+func (l *loggingDB) Has(key []byte) (bool, error) {
+	return l.backend.Has(key)
+}
+
+// Get counts the lookup, then delegates to the backend.
+func (l *loggingDB) Get(key []byte) ([]byte, bool, error) {
+	l.getCount++
+
+	return l.backend.Get(key)
+}
+
+func (l *loggingDB) Set(key []byte, value []byte) error {
+	return l.backend.Set(key, value)
+}
+
+func (l *loggingDB) Delete(key []byte) error {
+	return l.backend.Delete(key)
+}
+
+func (l *loggingDB) NewBatch() kvdb.Batch {
+	return l.backend.NewBatch()
+}
+
+func (l *loggingDB) NewIterator(prefix []byte, start []byte) kvdb.Iterator {
+	return l.backend.NewIterator(prefix, start)
+}
+
+func (l *loggingDB) Close() error {
+	return l.backend.Close()
+}
+
+// makeLargeTestTrie create a sample test trie with 10k hashed key/value
+// pairs backed by a loggingDB, so tests can count database lookups.
+func makeLargeTestTrie() (*Database, *StateTrie, *loggingDB) {
+	// Create an empty trie
+	logDB := &loggingDB{0, memorydb.New()}
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(rawdb.NewDatabase(logDB), logger)
+	trie, _ := NewStateTrie(TrieID(types.Hash{}), triedb, logger)
+
+	// Fill it with some arbitrary data
+	for i := 0; i < 10000; i++ {
+		key := make([]byte, 32)
+		val := make([]byte, 32)
+		binary.BigEndian.PutUint64(key, uint64(i))
+		binary.BigEndian.PutUint64(val, uint64(i))
+		key = crypto.Keccak256(key)
+		val = crypto.Keccak256(val)
+		trie.Update(key, val)
+	}
+	_, nodes, _ := trie.Commit(false)
+	triedb.Update(NewWithNodeSet(nodes))
+	// Return the generated trie
+	return triedb, trie, logDB
+}
+
+// Tests that seeking into a large flushed trie stays cheap: the fast-path
+// seek should need only a handful of database reads.
+func TestNodeIteratorLargeTrie(t *testing.T) {
+	// Create some arbitrary test trie to iterate
+	db, trie, logDB := makeLargeTestTrie()
+	db.Cap(0) // flush everything
+	// Do a seek operation
+	trie.NodeIterator(types.StringToBytes("0x77667766776677766778855885885885"))
+	// master: 24 get operations
+	// this pr: 5 get operations
+	if have, want := logDB.getCount, uint64(5); have != want {
+		t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
+	}
+}
+
+// TestIteratorNodeBlob commits a small trie, flushes it to disk, then checks
+// that every hashed node reported by the iterator carries a NodeBlob equal to
+// the blob stored in the database — and that the database holds nothing more.
+func TestIteratorNodeBlob(t *testing.T) {
+	var (
+		db     = rawdb.NewMemoryDatabase()
+		triedb = NewDatabase(db, hclog.NewNullLogger())
+		trie   = NewEmpty(triedb)
+	)
+	vals := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"dog", "puppy"},
+		{"somethingveryoddindeedthis is", "myothernodedata"},
+	}
+	all := make(map[string]string)
+	for _, val := range vals {
+		all[val.k] = val.v
+		trie.Update([]byte(val.k), []byte(val.v))
+	}
+	_, nodes, _ := trie.Commit(false)
+	triedb.Update(NewWithNodeSet(nodes))
+	triedb.Cap(0)
+
+	// Collect the blob of every hashed node the iterator reports.
+	found := make(map[types.Hash][]byte)
+	it := trie.NodeIterator(nil)
+	for it.Next(true) {
+		if it.Hash() == (types.Hash{}) {
+			continue
+		}
+		found[it.Hash()] = it.NodeBlob()
+	}
+
+	// Every database entry must match an iterated node byte-for-byte.
+	dbIter := db.NewIterator(nil, nil)
+	defer dbIter.Release()
+
+	var count int
+	for dbIter.Next() {
+		got, present := found[types.BytesToHash(dbIter.Key())]
+		if !present {
+			t.Fatalf("Miss trie node %v", dbIter.Key())
+		}
+		if !bytes.Equal(got, dbIter.Value()) {
+			t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
+		}
+		count++
+	}
+	if count != len(found) {
+		t.Fatal("Find extra trie node via iterator")
+	}
+}
diff --git a/trie/logger.go b/trie/logger.go
new file mode 100644
index 0000000000..31b0a77069
--- /dev/null
+++ b/trie/logger.go
@@ -0,0 +1,16 @@
+package trie
+
+// Logger describes the interface that must be implemented by all loggers.
+// It is a minimal leveled-logging contract; presumably hclog.Logger (used
+// throughout the tests) satisfies it — verify against the hclog version.
+type Logger interface {
+	// Emit a message and key/value pairs at the DEBUG level
+	Debug(msg string, args ...interface{})
+
+	// Emit a message and key/value pairs at the INFO level
+	Info(msg string, args ...interface{})
+
+	// Emit a message and key/value pairs at the WARN level
+	Warn(msg string, args ...interface{})
+
+	// Emit a message and key/value pairs at the ERROR level
+	Error(msg string, args ...interface{})
+}
diff --git a/trie/node.go b/trie/node.go
new file mode 100644
index 0000000000..ceee993697
--- /dev/null
+++ b/trie/node.go
@@ -0,0 +1,280 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "fmt"
+ "io"
+ "strings"
+
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// indices maps a fullNode child slot to its printable nibble label; slot 16
+// (the value slot) is rendered as "[17]".
+var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
+
+// node is the common interface of all in-memory trie node variants.
+type node interface {
+	cache() (hashNode, bool)
+	encode(w rlp.EncoderBuffer)
+	fstring(string) string
+}
+
+type (
+	// fullNode is a branch node: 16 child slots indexed by nibble plus a
+	// 17th slot for a value terminating at this node.
+	fullNode struct {
+		Children [17]node // Actual trie node data to encode/decode (needs custom encoder)
+		flags    nodeFlag
+	}
+	// shortNode is an extension/leaf node carrying a compact key segment
+	// and a single child (a valueNode when the key has a terminator).
+	shortNode struct {
+		Key   []byte
+		Val   node
+		flags nodeFlag
+	}
+	// hashNode is a collapsed reference: the hash of a node stored elsewhere.
+	hashNode  []byte
+	// valueNode is a terminal payload stored at a leaf.
+	valueNode []byte
+)
+
+// nilValueNode is used when collapsing internal trie nodes for hashing, since
+// unset children need to serialize correctly.
+var nilValueNode = valueNode(nil)
+
+// EncodeRLP encodes a full node into the consensus RLP format.
+// It delegates to the node's encode method via an EncoderBuffer.
+func (n *fullNode) EncodeRLP(w io.Writer) error {
+	eb := rlp.NewEncoderBuffer(w)
+	n.encode(eb)
+
+	return eb.Flush()
+}
+
+// copy returns a shallow copy of the node: the struct (including the
+// Children array / Key slice header and flags) is duplicated, but children
+// and byte slices are shared with the original.
+func (n *fullNode) copy() *fullNode {
+	copyd := *n
+
+	return &copyd
+}
+
+// copy returns a shallow copy of the short node; Key and Val are shared.
+func (n *shortNode) copy() *shortNode {
+	copyd := *n
+
+	return &copyd
+}
+
+// nodeFlag contains caching-related metadata about a node.
+type nodeFlag struct {
+	hash  hashNode // cached hash of the node (may be nil)
+	dirty bool     // whether the node has changes that must be written to the database
+}
+
+// cache returns the cached hash and dirty flag. hashNode and valueNode have
+// no cached hash of their own and always report dirty.
+func (n *fullNode) cache() (hashNode, bool)  { return n.flags.hash, n.flags.dirty }
+func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
+func (n hashNode) cache() (hashNode, bool)   { return nil, true }
+func (n valueNode) cache() (hashNode, bool)  { return nil, true }
+
+// Pretty printing. String/fstring render a node tree for debugging; the
+// ind argument is the accumulated indentation prefix.
+func (n *fullNode) String() string  { return n.fstring("") }
+func (n *shortNode) String() string { return n.fstring("") }
+func (n hashNode) String() string   { return n.fstring("") }
+func (n valueNode) String() string  { return n.fstring("") }
+
+func (n *fullNode) fstring(ind string) string {
+	resp := fmt.Sprintf("[\n%s  ", ind)
+
+	for i, node := range &n.Children {
+		if node == nil {
+			resp += fmt.Sprintf("%s: ", indices[i])
+		} else {
+			resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+"  "))
+		}
+	}
+
+	return resp + fmt.Sprintf("\n%s] ", ind)
+}
+
+func (n *shortNode) fstring(ind string) string {
+	return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+"  "))
+}
+
+func (n hashNode) fstring(ind string) string {
+	return fmt.Sprintf("<%x> ", []byte(n))
+}
+
+func (n valueNode) fstring(ind string) string {
+	return fmt.Sprintf("%x ", []byte(n))
+}
+
+// mustDecodeNode is a wrapper of decodeNode and panic if any error is encountered.
+// Intended for blobs already validated against their hash, where a decode
+// failure indicates database corruption rather than bad input.
+func mustDecodeNode(hash, buf []byte) node {
+	n, err := decodeNode(hash, buf)
+	if err != nil {
+		panic(fmt.Sprintf("node %x: %v", hash, err))
+	}
+
+	return n
+}
+
+// mustDecodeNodeUnsafe is a wrapper of decodeNodeUnsafe and panic if any error is
+// encountered.
+func mustDecodeNodeUnsafe(hash, buf []byte) node {
+	n, err := decodeNodeUnsafe(hash, buf)
+	if err != nil {
+		panic(fmt.Sprintf("node %x: %v", hash, err))
+	}
+
+	return n
+}
+
+// decodeNode parses the RLP encoding of a trie node. It will deep-copy the passed
+// byte slice for decoding, so it's safe to modify the byte slice afterwards. The-
+// decode performance of this function is not optimal, but it is suitable for most
+// scenarios with low performance requirements and hard to determine whether the
+// byte slice be modified or not.
+func decodeNode(hash, buf []byte) (node, error) {
+	return decodeNodeUnsafe(hash, types.CopyBytes(buf))
+}
+
+// decodeNodeUnsafe parses the RLP encoding of a trie node. The passed byte slice
+// will be directly referenced by node without bytes deep copy, so the input MUST
+// not be changed after.
+// A 2-element list decodes to a shortNode, a 17-element list to a fullNode;
+// anything else is rejected.
+func decodeNodeUnsafe(hash, buf []byte) (node, error) {
+	if len(buf) == 0 {
+		return nil, io.ErrUnexpectedEOF
+	}
+
+	elems, _, err := rlp.SplitList(buf)
+	if err != nil {
+		return nil, fmt.Errorf("decode error: %w", err)
+	}
+
+	switch c, _ := rlp.CountValues(elems); c {
+	case 2:
+		n, err := decodeShort(hash, elems)
+
+		return n, wrapError(err, "short")
+	case 17:
+		n, err := decodeFull(hash, elems)
+
+		return n, wrapError(err, "full")
+	default:
+		return nil, fmt.Errorf("invalid number of list elements: %v", c)
+	}
+}
+
+// decodeShort decodes a 2-element node body into a shortNode. A key with a
+// terminator nibble yields a leaf (valueNode child); otherwise the second
+// element is decoded as a child reference (extension node).
+func decodeShort(hash, elems []byte) (node, error) {
+	kbuf, rest, err := rlp.SplitString(elems)
+	if err != nil {
+		return nil, err
+	}
+
+	flag := nodeFlag{hash: hash}
+
+	key := compactToHex(kbuf)
+	if hasTerm(key) {
+		// value node
+		val, _, err := rlp.SplitString(rest)
+		if err != nil {
+			return nil, fmt.Errorf("invalid value node: %w", err)
+		}
+
+		return &shortNode{key, valueNode(val), flag}, nil
+	}
+
+	r, _, err := decodeRef(rest)
+	if err != nil {
+		return nil, wrapError(err, "val")
+	}
+
+	return &shortNode{key, r, flag}, nil
+}
+
+// decodeFull decodes a 17-element node body into a fullNode: 16 child
+// references followed by an optional value in slot 16 (left nil when empty).
+func decodeFull(hash, elems []byte) (*fullNode, error) {
+	n := &fullNode{flags: nodeFlag{hash: hash}}
+
+	for i := 0; i < 16; i++ {
+		cld, rest, err := decodeRef(elems)
+		if err != nil {
+			return n, wrapError(err, fmt.Sprintf("[%d]", i))
+		}
+
+		n.Children[i], elems = cld, rest
+	}
+
+	val, _, err := rlp.SplitString(elems)
+	if err != nil {
+		return n, err
+	}
+
+	if len(val) > 0 {
+		n.Children[16] = valueNode(val)
+	}
+
+	return n, nil
+}
+
+// hashLen is the byte length of a node hash (32).
+const hashLen = len(types.Hash{})
+
+// decodeRef decodes a child reference inside a node body: an embedded node
+// encoded in place (RLP list, must be smaller than a hash), an empty string
+// for a nil child, or a 32-byte hash reference. It returns the decoded node
+// and the unconsumed remainder of the buffer.
+func decodeRef(buf []byte) (node, []byte, error) {
+	kind, val, rest, err := rlp.Split(buf)
+	if err != nil {
+		return nil, buf, err
+	}
+
+	switch {
+	case kind == rlp.List:
+		// 'embedded' node reference. The encoding must be smaller
+		// than a hash in order to be valid.
+		if size := len(buf) - len(rest); size > hashLen {
+			err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
+
+			return nil, buf, err
+		}
+
+		n, err := decodeNode(nil, buf)
+
+		return n, rest, err
+	case kind == rlp.String && len(val) == 0:
+		// empty node
+		return nil, rest, nil
+	case kind == rlp.String && len(val) == 32:
+		return hashNode(val), rest, nil
+	default:
+		return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
+	}
+}
+
+// wraps a decoding error with information about the path to the
+// invalid child node (for debugging encoding issues).
+type decodeError struct {
+	what  error
+	stack []string
+}
+
+// wrapError annotates err with the decode-path segment ctx. An existing
+// decodeError accumulates segments in place; any other error is wrapped in
+// a fresh decodeError. Returns nil for a nil err.
+func wrapError(err error, ctx string) error {
+	if err == nil {
+		return nil
+	}
+
+	//nolint:errorlint
+	if decErr, ok := err.(*decodeError); ok {
+		decErr.stack = append(decErr.stack, ctx)
+
+		return decErr
+	}
+
+	return &decodeError{err, []string{ctx}}
+}
+
+// Error renders the underlying error plus the accumulated decode path,
+// innermost segment first.
+func (err *decodeError) Error() string {
+	return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-"))
+}
diff --git a/trie/node_enc.go b/trie/node_enc.go
new file mode 100644
index 0000000000..d0d77e613c
--- /dev/null
+++ b/trie/node_enc.go
@@ -0,0 +1,94 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import "github.com/dogechain-lab/dogechain/helper/rlp"
+
+// nodeToBytes returns the canonical RLP encoding of n.
+func nodeToBytes(n node) []byte {
+	w := rlp.NewEncoderBuffer(nil)
+	n.encode(w)
+	// ToBytes copies the encoded data out before Flush, which (with a nil
+	// writer) only releases the encoder's internal buffer for reuse.
+	result := w.ToBytes()
+	w.Flush()
+
+	return result
+}
+
+// encode writes the branch node as a 17-element RLP list; nil children are
+// encoded as the empty string.
+func (n *fullNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+
+	for _, c := range n.Children {
+		if c != nil {
+			c.encode(w)
+		} else {
+			w.Write(rlp.EmptyString)
+		}
+	}
+
+	w.ListEnd(offset)
+}
+
+// encode writes the node as a two-element RLP list of [key, value]; a nil
+// value is encoded as the empty string.
+func (n *shortNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+	w.WriteBytes(n.Key)
+
+	if n.Val != nil {
+		n.Val.encode(w)
+	} else {
+		w.Write(rlp.EmptyString)
+	}
+
+	w.ListEnd(offset)
+}
+
+// encode writes the node hash reference as an RLP string.
+func (n hashNode) encode(w rlp.EncoderBuffer) {
+	w.WriteBytes(n)
+}
+
+// encode writes the leaf value as an RLP string.
+func (n valueNode) encode(w rlp.EncoderBuffer) {
+	w.WriteBytes(n)
+}
+
+// encode mirrors fullNode.encode for the raw (flag-less) representation.
+func (n rawFullNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+
+	for _, c := range n {
+		if c != nil {
+			c.encode(w)
+		} else {
+			w.Write(rlp.EmptyString)
+		}
+	}
+
+	w.ListEnd(offset)
+}
+
+// encode mirrors shortNode.encode for the raw (flag-less) representation.
+func (n *rawShortNode) encode(w rlp.EncoderBuffer) {
+	offset := w.List()
+	w.WriteBytes(n.Key)
+
+	if n.Val != nil {
+		n.Val.encode(w)
+	} else {
+		w.Write(rlp.EmptyString)
+	}
+
+	w.ListEnd(offset)
+}
+
+// encode emits the pre-encoded RLP blob verbatim.
+func (n rawNode) encode(w rlp.EncoderBuffer) {
+	w.Write(n)
+}
diff --git a/trie/node_test.go b/trie/node_test.go
new file mode 100644
index 0000000000..09a19ad311
--- /dev/null
+++ b/trie/node_test.go
@@ -0,0 +1,215 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/rlp"
+)
+
+// newTestFullNode builds the generic RLP input for a 17-element branch node:
+// sixteen distinct 32-byte child hashes followed by the value v.
+func newTestFullNode(v []byte) []interface{} {
+	fullNodeData := []interface{}{}
+	for i := 0; i < 16; i++ {
+		k := bytes.Repeat([]byte{byte(i + 1)}, 32)
+		fullNodeData = append(fullNodeData, k)
+	}
+	fullNodeData = append(fullNodeData, v)
+	return fullNodeData
+}
+
+func TestDecodeNestedNode(t *testing.T) {
+ fullNodeData := newTestFullNode([]byte("fullnode"))
+
+ data := [][]byte{}
+ for i := 0; i < 16; i++ {
+ data = append(data, nil)
+ }
+ data = append(data, []byte("subnode"))
+ fullNodeData[15] = data
+
+ buf := bytes.NewBuffer([]byte{})
+ rlp.Encode(buf, fullNodeData)
+
+ if _, err := decodeNode([]byte("testdecode"), buf.Bytes()); err != nil {
+ t.Fatalf("decode nested full node err: %v", err)
+ }
+}
+
+func TestDecodeFullNodeWrongSizeChild(t *testing.T) {
+ fullNodeData := newTestFullNode([]byte("wrongsizechild"))
+ fullNodeData[0] = []byte("00")
+ buf := bytes.NewBuffer([]byte{})
+ rlp.Encode(buf, fullNodeData)
+
+ _, err := decodeNode([]byte("testdecode"), buf.Bytes())
+ if _, ok := err.(*decodeError); !ok {
+ t.Fatalf("decodeNode returned wrong err: %v", err)
+ }
+}
+
+func TestDecodeFullNodeWrongNestedFullNode(t *testing.T) {
+ fullNodeData := newTestFullNode([]byte("fullnode"))
+
+ data := [][]byte{}
+ for i := 0; i < 16; i++ {
+ data = append(data, []byte("123456"))
+ }
+ data = append(data, []byte("subnode"))
+ fullNodeData[15] = data
+
+ buf := bytes.NewBuffer([]byte{})
+ rlp.Encode(buf, fullNodeData)
+
+ _, err := decodeNode([]byte("testdecode"), buf.Bytes())
+ if _, ok := err.(*decodeError); !ok {
+ t.Fatalf("decodeNode returned wrong err: %v", err)
+ }
+}
+
+func TestDecodeFullNode(t *testing.T) {
+ fullNodeData := newTestFullNode([]byte("decodefullnode"))
+ buf := bytes.NewBuffer([]byte{})
+ rlp.Encode(buf, fullNodeData)
+
+ _, err := decodeNode([]byte("testdecode"), buf.Bytes())
+ if err != nil {
+ t.Fatalf("decode full node err: %v", err)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkEncodeShortNode
+// BenchmarkEncodeShortNode-8 16878850 70.81 ns/op 48 B/op 1 allocs/op
+func BenchmarkEncodeShortNode(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ nodeToBytes(node)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkEncodeFullNode
+// BenchmarkEncodeFullNode-8 4323273 284.4 ns/op 576 B/op 1 allocs/op
+func BenchmarkEncodeFullNode(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ nodeToBytes(node)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeShortNode
+// BenchmarkDecodeShortNode-8 7925638 151.0 ns/op 157 B/op 4 allocs/op
+func BenchmarkDecodeShortNode(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNode(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeShortNodeUnsafe
+// BenchmarkDecodeShortNodeUnsafe-8 9027476 128.6 ns/op 109 B/op 3 allocs/op
+func BenchmarkDecodeShortNodeUnsafe(b *testing.B) {
+ node := &shortNode{
+ Key: []byte{0x1, 0x2},
+ Val: hashNode(randBytes(32)),
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNodeUnsafe(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeFullNode
+// BenchmarkDecodeFullNode-8 1597462 761.9 ns/op 1280 B/op 18 allocs/op
+func BenchmarkDecodeFullNode(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNode(hash, blob)
+ }
+}
+
+// goos: darwin
+// goarch: arm64
+// pkg: github.com/ethereum/go-ethereum/trie
+// BenchmarkDecodeFullNodeUnsafe
+// BenchmarkDecodeFullNodeUnsafe-8 1789070 687.1 ns/op 704 B/op 17 allocs/op
+func BenchmarkDecodeFullNodeUnsafe(b *testing.B) {
+ node := &fullNode{}
+ for i := 0; i < 16; i++ {
+ node.Children[i] = hashNode(randBytes(32))
+ }
+ blob := nodeToBytes(node)
+ hash := crypto.Keccak256(blob)
+
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ mustDecodeNodeUnsafe(hash, blob)
+ }
+}
diff --git a/trie/nodeset.go b/trie/nodeset.go
new file mode 100644
index 0000000000..259fff14be
--- /dev/null
+++ b/trie/nodeset.go
@@ -0,0 +1,226 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// memoryNode is all the information we know about a single cached trie node
+// in the memory.
+type memoryNode struct {
+	hash types.Hash // Node hash, computed by hashing rlp value, empty for deleted nodes
+	size uint16     // Byte size of the useful cached data, 0 for deleted nodes
+	node node       // Cached collapsed trie node, or raw rlp data, nil for deleted nodes
+}
+
+// memoryNodeSize is the raw size of a memoryNode data structure without any
+// node data included. It's an approximate size, but should be a lot better
+// than not counting them.
+//
+//nolint:unused
+var memoryNodeSize = int(reflect.TypeOf(memoryNode{}).Size())
+
+// memorySize returns the total memory size used by this node, i.e. cached
+// data plus struct overhead plus the key length.
+//
+//nolint:unused
+func (n *memoryNode) memorySize(key int) int {
+	return int(n.size) + memoryNodeSize + key
+}
+
+// rlp returns the raw rlp encoded blob of the cached trie node, either directly
+// from the cache, or by regenerating it from the collapsed node.
+//
+//nolint:unused
+func (n *memoryNode) rlp() []byte {
+	// Raw blobs are stored as rawNode and can be returned without re-encoding.
+	if node, ok := n.node.(rawNode); ok {
+		return node
+	}
+
+	return nodeToBytes(n.node)
+}
+
+// obj returns the decoded and expanded trie node, either directly from the cache,
+// or by regenerating it from the rlp encoded blob.
+//
+//nolint:unused
+func (n *memoryNode) obj() node {
+	if node, ok := n.node.(rawNode); ok {
+		return mustDecodeNode(n.hash[:], node)
+	}
+
+	return expandNode(n.hash[:], n.node)
+}
+
+// nodeWithPrev wraps the memoryNode with the previous node value, so both the
+// new state and the overwritten state of a node are available.
+type nodeWithPrev struct {
+	*memoryNode
+	prev []byte // RLP-encoded previous value, nil means it's non-existent
+}
+
+// unwrap returns the internal memoryNode object.
+//
+//nolint:unused
+func (n *nodeWithPrev) unwrap() *memoryNode {
+	return n.memoryNode
+}
+
+// memorySize returns the total memory size used by this node. It overloads
+// the function in memoryNode by counting the size of previous value as well.
+//
+//nolint:unused
+func (n *nodeWithPrev) memorySize(key int) int {
+	return n.memoryNode.memorySize(key) + len(n.prev)
+}
+
+// nodesWithOrder represents a collection of dirty nodes which includes
+// newly-inserted and updated nodes. The modification order of all nodes
+// is represented by order list.
+type nodesWithOrder struct {
+	order []string                 // the path list of dirty nodes, sort by insertion order
+	nodes map[string]*nodeWithPrev // the map of dirty nodes, keyed by node path
+}
+
+// NodeSet contains all dirty nodes collected during the commit operation.
+// Each node is keyed by path. It's not thread-safe to use.
+type NodeSet struct {
+	owner   types.Hash        // the identifier of the trie
+	updates *nodesWithOrder   // the set of updated nodes(newly inserted, updated)
+	deletes map[string][]byte // the map of deleted nodes, keyed by node path
+	leaves  []*leaf           // the list of dirty leaves
+}
+
+// NewNodeSet initializes an empty node set to be used for tracking dirty nodes
+// from a specific account or storage trie. The owner is zero for the account
+// trie and the owning account address hash for storage tries.
+func NewNodeSet(owner types.Hash) *NodeSet {
+	return &NodeSet{
+		owner: owner,
+		updates: &nodesWithOrder{
+			nodes: make(map[string]*nodeWithPrev),
+		},
+		deletes: make(map[string][]byte),
+	}
+}
+
+// // NewNodeSetWithDeletion initializes the nodeset with provided deletion set.
+// func NewNodeSetWithDeletion(owner types.Hash, paths [][]byte, prev [][]byte) *NodeSet {
+// set := NewNodeSet(owner)
+// for i, path := range paths {
+// set.markDeleted(path, prev[i])
+// }
+// return set
+// }
+
+// markUpdated marks the node as dirty(newly-inserted or updated) with provided
+// node path, node object along with its previous value.
+//
+// NOTE(review): the order list is appended unconditionally, so marking the
+// same path twice leaves a duplicate order entry while the map keeps only the
+// last node — Size() then over-counts; confirm callers mark each path once.
+func (set *NodeSet) markUpdated(path []byte, node *memoryNode, prev []byte) {
+	set.updates.order = append(set.updates.order, string(path))
+	set.updates.nodes[string(path)] = &nodeWithPrev{
+		memoryNode: node,
+		prev:       prev,
+	}
+}
+
+// markDeleted marks the node as deleted with provided path and previous value.
+func (set *NodeSet) markDeleted(path []byte, prev []byte) {
+	set.deletes[string(path)] = prev
+}
+
+// addLeaf collects the provided leaf node into set.
+func (set *NodeSet) addLeaf(node *leaf) {
+	set.leaves = append(set.leaves, node)
+}
+
+// Size returns the number of updated and deleted nodes contained in the set.
+func (set *NodeSet) Size() (int, int) {
+	return len(set.updates.order), len(set.deletes)
+}
+
+// Hashes returns the hashes of all updated nodes. TODO(rjl493456442) how can
+// we get rid of it?
+//
+// The result follows map iteration order and is therefore nondeterministic.
+func (set *NodeSet) Hashes() []types.Hash {
+	ret := make([]types.Hash, 0, len(set.updates.nodes))
+	for _, node := range set.updates.nodes {
+		ret = append(ret, node.hash)
+	}
+
+	return ret
+}
+
+// Summary returns a string-representation of the NodeSet, for debugging:
+// one line per updated ([*]/[+]), deleted ([-]) and leaf node.
+func (set *NodeSet) Summary() string {
+	var out = new(strings.Builder)
+
+	fmt.Fprintf(out, "nodeset owner: %v\n", set.owner)
+
+	if set.updates != nil {
+		// [*] = updated (has a previous value), [+] = newly inserted.
+		for _, key := range set.updates.order {
+			updated := set.updates.nodes[key]
+			if updated.prev != nil {
+				fmt.Fprintf(out, " [*]: %x -> %v prev: %x\n", key, updated.hash, updated.prev)
+			} else {
+				fmt.Fprintf(out, " [+]: %x -> %v\n", key, updated.hash)
+			}
+		}
+	}
+
+	for k, n := range set.deletes {
+		fmt.Fprintf(out, " [-]: %x -> %x\n", k, n)
+	}
+
+	for _, n := range set.leaves {
+		fmt.Fprintf(out, "[leaf]: %v\n", n)
+	}
+
+	return out.String()
+}
+
+// MergedNodeSet represents a merged dirty node set for a group of tries,
+// keyed by the owning trie's identifier.
+type MergedNodeSet struct {
+	sets map[types.Hash]*NodeSet
+}
+
+// NewMergedNodeSet initializes an empty merged set.
+func NewMergedNodeSet() *MergedNodeSet {
+	return &MergedNodeSet{sets: make(map[types.Hash]*NodeSet)}
+}
+
+// NewWithNodeSet constructs a merged nodeset with the provided single set.
+func NewWithNodeSet(set *NodeSet) *MergedNodeSet {
+	merged := NewMergedNodeSet()
+	// Merge cannot fail here: the freshly created merged set has no owner yet.
+	merged.Merge(set)
+
+	return merged
+}
+
+// Merge merges the provided dirty nodes of a trie into the set. The assumption
+// is held that no duplicated set belonging to the same trie will be merged twice.
+func (set *MergedNodeSet) Merge(other *NodeSet) error {
+	_, present := set.sets[other.owner]
+	if present {
+		return fmt.Errorf("duplicate trie for owner %#x", other.owner)
+	}
+
+	set.sets[other.owner] = other
+
+	return nil
+}
diff --git a/trie/proof.go b/trie/proof.go
new file mode 100644
index 0000000000..42fd1fa59d
--- /dev/null
+++ b/trie/proof.go
@@ -0,0 +1,714 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
+//
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+//
+// fromLevel allows skipping that many proof elements from the top of the path.
+func (t *Trie) Prove(key []byte, fromLevel uint, proofDB kvdb.KVWriter) error {
+	// Collect all nodes on the path to key.
+	var (
+		prefix []byte
+		nodes  []node
+		tn     = t.root
+	)
+
+	key = keybytesToHex(key)
+
+	for len(key) > 0 && tn != nil {
+		switch n := tn.(type) {
+		case *shortNode:
+			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+				// The trie doesn't contain the key.
+				tn = nil
+			} else {
+				tn = n.Val
+				prefix = append(prefix, n.Key...)
+				key = key[len(n.Key):]
+			}
+
+			nodes = append(nodes, n)
+		case *fullNode:
+			tn = n.Children[key[0]]
+			prefix = append(prefix, key[0])
+			key = key[1:]
+
+			nodes = append(nodes, n)
+		case hashNode:
+			// Retrieve the specified node from the underlying node reader.
+			// trie.resolveAndTrack is not used since in that function the
+			// loaded blob will be tracked, while it's not required here since
+			// all loaded nodes won't be linked to trie at all and track nodes
+			// may lead to out-of-memory issue.
+			var err error
+
+			tn, err = t.reader.node(prefix, types.BytesToHash(n))
+			if err != nil {
+				return err
+			}
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+		}
+	}
+
+	hasher := newHasher(false)
+	defer returnHasherToPool(hasher)
+
+	for i, n := range nodes {
+		if fromLevel > 0 {
+			// Skip the requested number of top-level proof elements.
+			fromLevel--
+
+			continue
+		}
+
+		var hn node
+		n, hn = hasher.proofHash(n)
+
+		if hash, ok := hn.(hashNode); ok || i == 0 {
+			// If the node's database encoding is a hash (or is the
+			// root node), it becomes a proof element.
+			enc := nodeToBytes(n)
+
+			if !ok {
+				hash = hasher.hashData(enc)
+			}
+
+			// NOTE(review): the Set error is discarded, so a failing proof
+			// writer goes unnoticed here — confirm this is acceptable.
+			proofDB.Set(hash, enc)
+		}
+	}
+
+	return nil
+}
+
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
+//
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+//
+// It simply delegates to the underlying Trie's Prove.
+func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDB kvdb.KVWriter) error {
+	return t.trie.Prove(key, fromLevel, proofDB)
+}
+
+// VerifyProof checks merkle proofs. The given proof must contain the value for
+// key in a trie with the given root hash. VerifyProof returns an error if the
+// proof contains invalid trie nodes or the wrong value. A (nil, nil) result
+// means the proof validly shows the key is absent from the trie.
+func VerifyProof(rootHash types.Hash, key []byte, proofDB kvdb.KVReader) (value []byte, err error) {
+	key = keybytesToHex(key)
+	wantHash := rootHash
+
+	for i := 0; ; i++ {
+		// NOTE(review): the Get error is discarded, making a read failure
+		// indistinguishable from a missing proof node — confirm intended.
+		buf, _, _ := proofDB.Get(wantHash[:])
+		if buf == nil {
+			return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
+		}
+
+		n, err := decodeNode(wantHash[:], buf)
+		if err != nil {
+			return nil, fmt.Errorf("bad proof node %d: %w", i, err)
+		}
+
+		// Descend as far as possible within this proof element.
+		keyrest, cld := get(n, key, true)
+
+		switch cld := cld.(type) {
+		case nil:
+			// The trie doesn't contain the key.
+			return nil, nil
+		case hashNode:
+			// Follow the hash reference into the next proof element.
+			key = keyrest
+
+			copy(wantHash[:], cld)
+		case valueNode:
+			return cld, nil
+		}
+	}
+}
+
+// proofToPath converts a merkle proof to trie node path. The main purpose of
+// this function is recovering a node path from the merkle proof stream. All
+// necessary nodes will be resolved and leave the remaining as hashnode.
+//
+// The given edge proof is allowed to be an existent or non-existent proof.
+// It returns the resolved root node and, for an existent proof, the value.
+func proofToPath(
+	rootHash types.Hash,
+	root node,
+	key []byte,
+	proofDB kvdb.KVReader,
+	allowNonExistent bool,
+) (node, []byte, error) {
+	// resolveNode retrieves and resolves trie node from merkle proof stream
+	resolveNode := func(hash types.Hash) (node, error) {
+		// NOTE(review): the Get error is discarded; a read failure surfaces
+		// as a "missing" proof node instead — confirm intended.
+		buf, _, _ := proofDB.Get(hash[:])
+		if buf == nil {
+			return nil, fmt.Errorf("proof node (hash %064x) missing", hash)
+		}
+
+		n, err := decodeNode(hash[:], buf)
+		if err != nil {
+			return nil, fmt.Errorf("bad proof node %w", err)
+		}
+
+		return n, err
+	}
+
+	// If the root node is empty, resolve it first.
+	// Root node must be included in the proof.
+	if root == nil {
+		n, err := resolveNode(rootHash)
+		if err != nil {
+			return nil, nil, err
+		}
+
+		root = n
+	}
+
+	var (
+		err           error
+		child, parent node
+		keyrest       []byte
+		valnode       []byte
+	)
+
+	key, parent = keybytesToHex(key), root
+
+	for {
+		keyrest, child = get(parent, key, false)
+
+		switch cld := child.(type) {
+		case nil:
+			// The trie doesn't contain the key. It's possible
+			// the proof is a non-existing proof, but at least
+			// we can prove all resolved nodes are correct, it's
+			// enough for us to prove range.
+			if allowNonExistent {
+				return root, nil, nil
+			}
+
+			return nil, nil, errors.New("the node is not contained in trie")
+		case *shortNode:
+			key, parent = keyrest, child // Already resolved
+
+			continue
+		case *fullNode:
+			key, parent = keyrest, child // Already resolved
+
+			continue
+		case hashNode:
+			// Resolve the referenced node from the proof stream.
+			child, err = resolveNode(types.BytesToHash(cld))
+			if err != nil {
+				return nil, nil, err
+			}
+		case valueNode:
+			valnode = cld
+		}
+
+		// Link the parent and child.
+		switch pnode := parent.(type) {
+		case *shortNode:
+			pnode.Val = child
+		case *fullNode:
+			pnode.Children[key[0]] = child
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode))
+		}
+
+		if len(valnode) > 0 {
+			return root, valnode, nil // The whole path is resolved
+		}
+
+		key, parent = keyrest, child
+	}
+}
+
+// unsetInternal removes all internal node references(hashnode, embedded node).
+// It should be called after a trie is constructed with two edge paths. Also
+// the given boundary keys must be the one used to construct the edge paths.
+//
+// It's the key step for range proof. All visited nodes should be marked dirty
+// since the node content might be modified. Besides it can happen that some
+// fullnodes only have one child which is disallowed. But if the proof is valid,
+// the missing children will be filled, otherwise it will be thrown anyway.
+//
+// Note we have the assumption here the given boundary keys are different
+// and right is larger than left.
+//
+// The returned boolean reports whether the entire trie (the root itself)
+// was unset.
+func unsetInternal(n node, left []byte, right []byte) (bool, error) {
+	left, right = keybytesToHex(left), keybytesToHex(right)
+
+	// Step down to the fork point. There are two scenarios can happen:
+	// - the fork point is a shortnode: either the key of left proof or
+	//   right proof doesn't match with shortnode's key.
+	// - the fork point is a fullnode: both two edge proofs are allowed
+	//   to point to a non-existent key.
+	var (
+		pos    = 0
+		parent node
+
+		// fork indicator, 0 means no fork, -1 means proof is less, 1 means proof is greater
+		shortForkLeft, shortForkRight int
+	)
+
+findFork:
+	for {
+		switch rn := (n).(type) {
+		case *shortNode:
+			rn.flags = nodeFlag{dirty: true}
+
+			// If either the key of left proof or right proof doesn't match with
+			// shortnode, stop here and the forkpoint is the shortnode.
+			if len(left)-pos < len(rn.Key) {
+				shortForkLeft = bytes.Compare(left[pos:], rn.Key)
+			} else {
+				shortForkLeft = bytes.Compare(left[pos:pos+len(rn.Key)], rn.Key)
+			}
+
+			if len(right)-pos < len(rn.Key) {
+				shortForkRight = bytes.Compare(right[pos:], rn.Key)
+			} else {
+				shortForkRight = bytes.Compare(right[pos:pos+len(rn.Key)], rn.Key)
+			}
+
+			if shortForkLeft != 0 || shortForkRight != 0 {
+				break findFork
+			}
+
+			parent = n
+			n, pos = rn.Val, pos+len(rn.Key)
+		case *fullNode:
+			rn.flags = nodeFlag{dirty: true}
+
+			// If either the node pointed by left proof or right proof is nil,
+			// stop here and the forkpoint is the fullnode.
+			leftnode, rightnode := rn.Children[left[pos]], rn.Children[right[pos]]
+			if leftnode == nil || rightnode == nil || leftnode != rightnode {
+				break findFork
+			}
+
+			parent = n
+			n, pos = rn.Children[left[pos]], pos+1
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+		}
+	}
+
+	switch rn := n.(type) {
+	case *shortNode:
+		// There can have these five scenarios:
+		// - both proofs are less than the trie path => no valid range
+		// - both proofs are greater than the trie path => no valid range
+		// - left proof is less and right proof is greater => valid range, unset the shortnode entirely
+		// - left proof points to the shortnode, but right proof is greater
+		// - right proof points to the shortnode, but left proof is less
+		if shortForkLeft == -1 && shortForkRight == -1 {
+			return false, errors.New("empty range")
+		}
+
+		if shortForkLeft == 1 && shortForkRight == 1 {
+			return false, errors.New("empty range")
+		}
+
+		if shortForkLeft != 0 && shortForkRight != 0 {
+			// The fork point is root node, unset the entire trie
+			if parent == nil {
+				return true, nil
+			}
+
+			// The loop above only assigns fullNodes to parent, so the
+			// assertion is safe.
+			//nolint:forcetypeassert
+			parent.(*fullNode).Children[left[pos-1]] = nil
+
+			return false, nil
+		}
+
+		// Only one proof points to non-existent key.
+		if shortForkRight != 0 {
+			if _, ok := rn.Val.(valueNode); ok {
+				// The fork point is root node, unset the entire trie
+				if parent == nil {
+					return true, nil
+				}
+
+				//nolint:forcetypeassert
+				parent.(*fullNode).Children[left[pos-1]] = nil
+
+				return false, nil
+			}
+
+			return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false)
+		}
+
+		if shortForkLeft != 0 {
+			if _, ok := rn.Val.(valueNode); ok {
+				// The fork point is root node, unset the entire trie
+				if parent == nil {
+					return true, nil
+				}
+
+				//nolint:forcetypeassert
+				parent.(*fullNode).Children[right[pos-1]] = nil
+
+				return false, nil
+			}
+
+			return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+		}
+
+		return false, nil
+	case *fullNode:
+		// unset all internal nodes in the forkpoint
+		for i := left[pos] + 1; i < right[pos]; i++ {
+			rn.Children[i] = nil
+		}
+
+		if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
+			return false, err
+		}
+
+		if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
+			return false, err
+		}
+
+		return false, nil
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+	}
+}
+
+// unset removes all internal node references either the left most or right most.
+// It can meet these scenarios:
+//
+//   - The given path is existent in the trie, unset the associated nodes with the
+//     specific direction
+//   - The given path is non-existent in the trie
+//   - the fork point is a fullnode, the corresponding child pointed by path
+//     is nil, return
+//   - the fork point is a shortnode, the shortnode is included in the range,
+//     keep the entire branch and return.
+//   - the fork point is a shortnode, the shortnode is excluded in the range,
+//     unset the entire branch.
+func unset(parent node, child node, key []byte, pos int, removeLeft bool) error {
+	switch cld := child.(type) {
+	case *fullNode:
+		if removeLeft {
+			// Drop every sibling strictly left of the boundary path.
+			for i := 0; i < int(key[pos]); i++ {
+				cld.Children[i] = nil
+			}
+
+			cld.flags = nodeFlag{dirty: true}
+		} else {
+			// Drop every sibling strictly right of the boundary path.
+			for i := key[pos] + 1; i < 16; i++ {
+				cld.Children[i] = nil
+			}
+
+			cld.flags = nodeFlag{dirty: true}
+		}
+
+		return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft)
+	case *shortNode:
+		if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) {
+			// Find the fork point, it's an non-existent branch.
+			if removeLeft {
+				// If the key of fork shortnode is greater than the
+				// path(it doesn't belong to the range), keep
+				// it with the cached hash available. That means do
+				// nothing.
+				if bytes.Compare(cld.Key, key[pos:]) < 0 {
+					// The key of fork shortnode is less than the path
+					// (it belongs to the range), unset the entire
+					// branch. The parent must be a fullnode.
+					// NOTE(review): the ignored ok would make fn nil and
+					// panic below if the invariant were ever violated.
+					fn, _ := parent.(*fullNode)
+					fn.Children[key[pos-1]] = nil
+				}
+			} else {
+				// If the key of fork shortnode is less than the
+				// path(it doesn't belong to the range), keep
+				// it with the cached hash available. That means do
+				// nothing.
+				if bytes.Compare(cld.Key, key[pos:]) > 0 {
+					// The key of fork shortnode is greater than the
+					// path(it belongs to the range), unset the entrie
+					// branch. The parent must be a fullnode.
+					fn, _ := parent.(*fullNode)
+					fn.Children[key[pos-1]] = nil
+				}
+			}
+
+			return nil
+		}
+
+		if _, ok := cld.Val.(valueNode); ok {
+			// Reached the boundary leaf itself; remove it from its parent,
+			// which must be a fullnode.
+			fn, _ := parent.(*fullNode)
+			fn.Children[key[pos-1]] = nil
+
+			return nil
+		}
+
+		cld.flags = nodeFlag{dirty: true}
+
+		return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft)
+	case nil:
+		// If the node is nil, then it's a child of the fork point
+		// fullnode(it's a non-existent branch).
+		return nil
+	default:
+		panic("it shouldn't happen") // hashNode, valueNode
+	}
+}
+
+// hasRightElement returns the indicator whether there exists more elements
+// on the right side of the given path. The given path can point to an existent
+// key or a non-existent one. This function has the assumption that the whole
+// path should already be resolved.
+func hasRightElement(node node, key []byte) bool {
+	pos, key := 0, keybytesToHex(key)
+
+	for node != nil {
+		switch rn := node.(type) {
+		case *fullNode:
+			// Any non-nil sibling to the right of the path means more entries.
+			for i := key[pos] + 1; i < 16; i++ {
+				if rn.Children[i] != nil {
+					return true
+				}
+			}
+
+			node, pos = rn.Children[key[pos]], pos+1
+		case *shortNode:
+			if len(key)-pos < len(rn.Key) || !bytes.Equal(rn.Key, key[pos:pos+len(rn.Key)]) {
+				// Diverged: the shortnode is to the right iff its key is greater.
+				return bytes.Compare(rn.Key, key[pos:]) > 0
+			}
+
+			node, pos = rn.Val, pos+len(rn.Key)
+		case valueNode:
+			return false // We have resolved the whole path
+		default:
+			// Unresolved hashnode violates the fully-resolved precondition.
+			panic(fmt.Sprintf("%T: invalid node: %v", node, node)) // hashnode
+		}
+	}
+
+	return false
+}
+
+// VerifyRangeProof checks whether the given leaf nodes and edge proof
+// can prove the given trie leaves range is matched with the specific root.
+// Besides, the range should be consecutive (no gap inside) and monotonic
+// increasing.
+//
+// Note the given proof actually contains two edge proofs. Both of them can
+// be non-existent proofs. For example the first proof is for a non-existent
+// key 0x03, the last proof is for a non-existent key 0x10. The given batch
+// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove the given
+// batch is valid.
+//
+// The firstKey is paired with firstProof, not necessarily the same as keys[0]
+// (unless firstProof is an existent proof). Similarly, lastKey and lastProof
+// are paired.
+//
+// Besides the normal case, this function can also be used to verify the following
+// range proofs:
+//
+// - All elements proof. In this case the proof can be nil, but the range should
+// be all the leaves in the trie.
+//
+// - One element proof. In this case no matter the edge proof is a non-existent
+// proof or not, we can always verify the correctness of the proof.
+//
+// - Zero element proof. In this case a single non-existent proof is enough to prove.
+// Besides, if there are still some other leaves available on the right side, then
+// an error will be returned.
+//
+// Besides returning an error to indicate whether the proof is valid, the function
+// also returns a flag indicating whether there exist more accounts/slots in the trie.
+//
+// Note: This method does not verify that the proof is of minimal form. If the input
+// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful'
+// data, then the proof will still be accepted.
+func VerifyRangeProof(
+	rootHash types.Hash,
+	firstKey []byte,
+	lastKey []byte,
+	keys [][]byte,
+	values [][]byte,
+	proof kvdb.KVReader,
+) (bool, error) {
+	if len(keys) != len(values) {
+		return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
+	}
+
+	// Ensure the received batch is monotonic increasing and contains no deletions
+	for i := 0; i < len(keys)-1; i++ {
+		if bytes.Compare(keys[i], keys[i+1]) >= 0 {
+			return false, errors.New("range is not monotonically increasing")
+		}
+	}
+
+	// Empty values are deletion markers and must never appear in a range proof
+	for _, value := range values {
+		if len(value) == 0 {
+			return false, errors.New("range contains deletion")
+		}
+	}
+
+	// Special case, there is no edge proof at all. The given range is expected
+	// to be the whole leaf-set in the trie.
+	if proof == nil {
+		tr := NewStackTrie(nil)
+
+		for index, key := range keys {
+			tr.TryUpdate(key, values[index])
+		}
+
+		if have, want := tr.Hash(), rootHash; have != want {
+			return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
+		}
+
+		return false, nil // No more elements
+	}
+
+	// Special case, there is a provided edge proof but zero key/value
+	// pairs, ensure there are no more accounts / slots in the trie.
+	if len(keys) == 0 {
+		root, val, err := proofToPath(rootHash, nil, firstKey, proof, true)
+		if err != nil {
+			return false, err
+		}
+
+		if val != nil || hasRightElement(root, firstKey) {
+			return false, errors.New("more entries available")
+		}
+
+		return false, nil
+	}
+
+	// Special case, there is only one element and two edge keys are same.
+	// In this case, we can't construct two edge paths. So handle it here.
+	if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
+		root, val, err := proofToPath(rootHash, nil, firstKey, proof, false)
+		if err != nil {
+			return false, err
+		}
+
+		if !bytes.Equal(firstKey, keys[0]) {
+			return false, errors.New("correct proof but invalid key")
+		}
+
+		if !bytes.Equal(val, values[0]) {
+			return false, errors.New("correct proof but invalid data")
+		}
+
+		return hasRightElement(root, firstKey), nil
+	}
+
+	// Ok, in all other cases, we require two edge paths available.
+	// First check the validity of edge keys.
+	if bytes.Compare(firstKey, lastKey) >= 0 {
+		return false, errors.New("invalid edge keys")
+	}
+
+	// todo(rjl493456442) different length edge keys should be supported
+	if len(firstKey) != len(lastKey) {
+		return false, errors.New("inconsistent edge keys")
+	}
+
+	// Convert the edge proofs to edge trie paths. Then we can
+	// have the same tree architecture with the original one.
+	// For the first edge proof, non-existent proof is allowed.
+	root, _, err := proofToPath(rootHash, nil, firstKey, proof, true)
+	if err != nil {
+		return false, err
+	}
+
+	// Pass the root node here, the second path will be merged
+	// with the first one. For the last edge proof, non-existent
+	// proof is also allowed.
+	root, _, err = proofToPath(rootHash, root, lastKey, proof, true)
+	if err != nil {
+		return false, err
+	}
+
+	// Remove all internal references. All the removed parts should
+	// be re-filled(or re-constructed) by the given leaves range.
+	empty, err := unsetInternal(root, firstKey, lastKey)
+	if err != nil {
+		return false, err
+	}
+
+	// Rebuild the trie with the leaf stream, the shape of trie
+	// should be same with the original one.
+	tr := &Trie{root: root, reader: newEmptyReader()}
+	if empty {
+		tr.root = nil
+	}
+
+	for index, key := range keys {
+		tr.TryUpdate(key, values[index])
+	}
+
+	if tr.Hash() != rootHash {
+		return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
+	}
+
+	// Report whether entries beyond the last supplied key remain in the trie.
+	return hasRightElement(tr.root, keys[len(keys)-1]), nil
+}
+
+// get returns the child of the given node. Return nil if the
+// node with specified key doesn't exist at all.
+//
+// There is an additional flag `skipResolved`. If it's set then
+// all resolved nodes won't be returned.
+func get(tn node, key []byte, skipResolved bool) ([]byte, node) {
+	for {
+		switch n := tn.(type) {
+		case *shortNode:
+			// Bail out if the trie path diverges from the requested key.
+			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+				return nil, nil
+			}
+
+			// Consume the shared prefix and descend into the value node.
+			tn = n.Val
+			key = key[len(n.Key):]
+
+			if !skipResolved {
+				return key, tn
+			}
+		case *fullNode:
+			// Consume one nibble and descend into the matching child.
+			tn = n.Children[key[0]]
+			key = key[1:]
+
+			if !skipResolved {
+				return key, tn
+			}
+		case hashNode:
+			// Unresolved node: hand the hash back with the remaining key.
+			return key, n
+		case nil:
+			return key, nil
+		case valueNode:
+			return nil, n
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+		}
+	}
+}
diff --git a/trie/proof_test.go b/trie/proof_test.go
new file mode 100644
index 0000000000..30431c31e9
--- /dev/null
+++ b/trie/proof_test.go
@@ -0,0 +1,1108 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ mrand "math/rand"
+ "sort"
+ "testing"
+ "time"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/kvdb/memorydb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+func init() {
+	// Seed math/rand so the randomized proof tests vary between runs.
+	mrand.Seed(time.Now().Unix())
+}
+
+// makeProvers creates Merkle trie provers based on different implementations to
+// test all variations.
+func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
+	var provers []func(key []byte) *memorydb.Database
+
+	// Create a direct trie based Merkle prover
+	provers = append(provers, func(key []byte) *memorydb.Database {
+		proof := memorydb.New()
+		// Prove error deliberately ignored; callers validate the resulting proof.
+		trie.Prove(key, 0, proof)
+		return proof
+	})
+	// Create a leaf iterator based Merkle prover
+	provers = append(provers, func(key []byte) *memorydb.Database {
+		proof := memorydb.New()
+		// Only emit proof nodes when the iterator actually lands on the key,
+		// i.e. the key exists in the trie.
+		if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
+			for _, p := range it.Prove() {
+				proof.Set(crypto.Keccak256(p), p)
+			}
+		}
+		return proof
+	})
+	return provers
+}
+
+// TestProof checks that both prover implementations produce proofs that
+// verify successfully for every key of a randomly generated trie.
+func TestProof(t *testing.T) {
+	trie, vals := randomTrie(500)
+	root := trie.Hash()
+	for i, prover := range makeProvers(trie) {
+		for _, kv := range vals {
+			proof := prover(kv.k)
+			if proof == nil {
+				t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
+			}
+			val, err := VerifyProof(root, kv.k, proof)
+			if err != nil {
+				t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
+			}
+			if !bytes.Equal(val, kv.v) {
+				t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
+			}
+		}
+	}
+}
+
+// TestOneElementProof verifies proving and verifying the single key of a
+// one-element trie with both prover implementations.
+func TestOneElementProof(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	updateString(trie, "k", "v")
+	for i, prover := range makeProvers(trie) {
+		proof := prover([]byte("k"))
+		if proof == nil {
+			t.Fatalf("prover %d: nil proof", i)
+		}
+		// A single-entry trie is a lone short node, so the proof is one node.
+		if proof.Len() != 1 {
+			t.Errorf("prover %d: proof should have one element", i)
+		}
+		val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+		if err != nil {
+			t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
+		}
+		if !bytes.Equal(val, []byte("v")) {
+			// Fixed message: the expected value is 'v' ('k' is the key).
+			t.Fatalf("prover %d: verified value mismatch: have %x, want 'v'", i, val)
+		}
+	}
+}
+
+// TestBadProof checks that tampering with a random node of an otherwise
+// valid proof causes verification to fail.
+func TestBadProof(t *testing.T) {
+	trie, vals := randomTrie(800)
+	root := trie.Hash()
+	for i, prover := range makeProvers(trie) {
+		for _, kv := range vals {
+			proof := prover(kv.k)
+			if proof == nil {
+				t.Fatalf("prover %d: nil proof", i)
+			}
+			// Pick a random proof node, remove it, flip one byte and re-insert
+			// it under the (now different) hash of the mutated blob.
+			it := proof.NewIterator(nil, nil)
+			for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ {
+				it.Next()
+			}
+			key := it.Key()
+			val, _, _ := proof.Get(key)
+			proof.Delete(key)
+			it.Release()
+
+			mutateByte(val)
+			proof.Set(crypto.Keccak256(val), val)
+
+			if _, err := VerifyProof(root, kv.k, proof); err == nil {
+				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
+			}
+		}
+	}
+}
+
+// Tests that missing keys can also be proven. The test explicitly uses a single
+// entry trie and checks for missing keys both before and after the single entry.
+func TestMissingKeyProof(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	updateString(trie, "k", "v")
+
+	// Probe keys sorting both before ("a", "j") and after ("l", "z") the entry.
+	for i, key := range []string{"a", "j", "l", "z"} {
+		proof := memorydb.New()
+		trie.Prove([]byte(key), 0, proof)
+
+		if proof.Len() != 1 {
+			t.Errorf("test %d: proof should have one element", i)
+		}
+		val, err := VerifyProof(trie.Hash(), []byte(key), proof)
+		if err != nil {
+			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
+		}
+		// A valid non-existence proof yields a nil value without error.
+		if val != nil {
+			t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
+		}
+	}
+}
+
+// entrySlice orders key/value entries by key; it implements sort.Interface.
+type entrySlice []*kv
+
+func (p entrySlice) Len() int { return len(p) }
+func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
+func (p entrySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
+
+// TestRangeProof tests normal range proof with both edge proofs
+// as the existent proof. The test cases are generated randomly.
+func TestRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+	for i := 0; i < 500; i++ {
+		// Random non-empty sub-range [start, end) of the sorted entries.
+		start := mrand.Intn(len(entries))
+		end := mrand.Intn(len(entries)-start) + start + 1
+
+		proof := memorydb.New()
+		if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+			t.Fatalf("Failed to prove the last node %v", err)
+		}
+		// NOTE: this inner vals shadows the outer vals map from randomTrie.
+		var keys [][]byte
+		var vals [][]byte
+		for i := start; i < end; i++ {
+			keys = append(keys, entries[i].k)
+			vals = append(vals, entries[i].v)
+		}
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+		if err != nil {
+			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+}
+
+// TestRangeProofWithNonExistentProof tests normal range proofs with two
+// non-existent edge proofs. The test cases are generated randomly.
+func TestRangeProofWithNonExistentProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+	for i := 0; i < 500; i++ {
+		start := mrand.Intn(len(entries))
+		end := mrand.Intn(len(entries)-start) + start + 1
+		proof := memorydb.New()
+
+		// Short circuit if the decreased key is same with the previous key
+		first := decreaseKey(types.CopyBytes(entries[start].k))
+		if start != 0 && bytes.Equal(first, entries[start-1].k) {
+			continue
+		}
+		// Short circuit if the decreased key is underflow
+		if bytes.Compare(first, entries[start].k) > 0 {
+			continue
+		}
+		// Short circuit if the increased key is same with the next key
+		last := increaseKey(types.CopyBytes(entries[end-1].k))
+		if end != len(entries) && bytes.Equal(last, entries[end].k) {
+			continue
+		}
+		// Short circuit if the increased key is overflow
+		if bytes.Compare(last, entries[end-1].k) < 0 {
+			continue
+		}
+		if err := trie.Prove(first, 0, proof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(last, 0, proof); err != nil {
+			t.Fatalf("Failed to prove the last node %v", err)
+		}
+		var keys [][]byte
+		var vals [][]byte
+		for i := start; i < end; i++ {
+			keys = append(keys, entries[i].k)
+			vals = append(vals, entries[i].v)
+		}
+		_, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+		if err != nil {
+			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+	// Special case, two edge proofs for two edge key.
+	proof := memorydb.New()
+	first := types.StringToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+	last := types.StringToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
+	if err := trie.Prove(first, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(last, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var k [][]byte
+	var v [][]byte
+	for i := 0; i < len(entries); i++ {
+		k = append(k, entries[i].k)
+		v = append(v, entries[i].v)
+	}
+	_, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	if err != nil {
+		// Fixed: typo ("rang") and include the underlying error.
+		t.Fatalf("Failed to verify whole range with non-existent edges: %v", err)
+	}
+}
+
+// TestRangeProofWithInvalidNonExistentProof tests such scenarios:
+// - There exists a gap between the first element and the left edge proof
+// - There exists a gap between the last element and the right edge proof
+func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
+ trie, vals := randomTrie(4096)
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+
+ // Case 1
+ start, end := 100, 200
+ first := decreaseKey(types.CopyBytes(entries[start].k))
+
+ proof := memorydb.New()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ start = 105 // Gap created
+ k := make([][]byte, 0)
+ v := make([][]byte, 0)
+ for i := start; i < end; i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
+ if err == nil {
+ t.Fatalf("Expected to detect the error, got nil")
+ }
+
+ // Case 2
+ start, end = 100, 200
+ last := increaseKey(types.CopyBytes(entries[end-1].k))
+ proof = memorydb.New()
+ if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ end = 195 // Capped slice
+ k = make([][]byte, 0)
+ v = make([][]byte, 0)
+ for i := start; i < end; i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
+ if err == nil {
+ t.Fatalf("Expected to detect the error, got nil")
+ }
+}
+
+// TestOneElementRangeProof tests the proof with only one
+// element. The first edge proof can be existent one or
+// non-existent one.
+func TestOneElementRangeProof(t *testing.T) {
+ trie, vals := randomTrie(4096)
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+
+ // One element with existent edge proof, both edge proofs
+ // point to the SAME key.
+ start := 1000
+ proof := memorydb.New()
+ if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // One element with left non-existent edge proof
+ start = 1000
+ first := decreaseKey(types.CopyBytes(entries[start].k))
+ proof = memorydb.New()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // One element with right non-existent edge proof
+ start = 1000
+ last := increaseKey(types.CopyBytes(entries[start].k))
+ proof = memorydb.New()
+ if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // One element with two non-existent edge proofs
+ start = 1000
+ first, last = decreaseKey(types.CopyBytes(entries[start].k)), increaseKey(types.CopyBytes(entries[start].k))
+ proof = memorydb.New()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // Test the mini trie with only a single element.
+ tinyTrie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ entry := &kv{randBytes(32), randBytes(20), false}
+ tinyTrie.Update(entry.k, entry.v)
+
+ first = types.StringToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+ last = entry.k
+ proof = memorydb.New()
+ if err := tinyTrie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := tinyTrie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+}
+
+// TestAllElementsProof tests the range proof with all elements.
+// The edge proofs can be nil.
+func TestAllElementsProof(t *testing.T) {
+ trie, vals := randomTrie(4096)
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+
+ var k [][]byte
+ var v [][]byte
+ for i := 0; i < len(entries); i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // With edge proofs, it should still work.
+ proof := memorydb.New()
+ if err := trie.Prove(entries[0].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+
+ // Even with non-existent edge proofs, it should still work.
+ proof = memorydb.New()
+ first := types.StringToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+ last := types.StringToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+}
+
+// TestSingleSideRangeProof tests the range starts from zero.
+func TestSingleSideRangeProof(t *testing.T) {
+	for i := 0; i < 64; i++ {
+		trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+		var entries entrySlice
+		for i := 0; i < 4096; i++ {
+			value := &kv{randBytes(32), randBytes(20), false}
+			trie.Update(value.k, value.v)
+			entries = append(entries, value)
+		}
+		sort.Sort(entries)
+
+		var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+		for _, pos := range cases {
+			proof := memorydb.New()
+			// Left edge is the all-zero key (non-existent proof).
+			if err := trie.Prove(types.Hash{}.Bytes(), 0, proof); err != nil {
+				t.Fatalf("Failed to prove the first node %v", err)
+			}
+			// entries[pos] is the right edge; the message previously said
+			// "first node" (copy-paste bug).
+			if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
+				t.Fatalf("Failed to prove the last node %v", err)
+			}
+			k := make([][]byte, 0)
+			v := make([][]byte, 0)
+			for i := 0; i <= pos; i++ {
+				k = append(k, entries[i].k)
+				v = append(v, entries[i].v)
+			}
+			_, err := VerifyRangeProof(trie.Hash(), types.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
+			if err != nil {
+				t.Fatalf("Expected no error, got %v", err)
+			}
+		}
+	}
+}
+
+// TestReverseSingleSideRangeProof tests the range ends with 0xffff...fff.
+func TestReverseSingleSideRangeProof(t *testing.T) {
+ for i := 0; i < 64; i++ {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ var entries entrySlice
+ for i := 0; i < 4096; i++ {
+ value := &kv{randBytes(32), randBytes(20), false}
+ trie.Update(value.k, value.v)
+ entries = append(entries, value)
+ }
+ sort.Sort(entries)
+
+ var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+ for _, pos := range cases {
+ proof := memorydb.New()
+ if err := trie.Prove(entries[pos].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ last := types.StringToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+ if err := trie.Prove(last.Bytes(), 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ k := make([][]byte, 0)
+ v := make([][]byte, 0)
+ for i := pos; i < len(entries); i++ {
+ k = append(k, entries[i].k)
+ v = append(v, entries[i].v)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+ if err != nil {
+ t.Fatalf("Expected no error, got %v", err)
+ }
+ }
+ }
+}
+
+// TestBadRangeProof tests a few cases which the proof is wrong.
+// The prover is expected to detect the error.
+func TestBadRangeProof(t *testing.T) {
+ trie, vals := randomTrie(4096)
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+
+ for i := 0; i < 500; i++ {
+ start := mrand.Intn(len(entries))
+ end := mrand.Intn(len(entries)-start) + start + 1
+ proof := memorydb.New()
+ if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ var keys [][]byte
+ var vals [][]byte
+ for i := start; i < end; i++ {
+ keys = append(keys, entries[i].k)
+ vals = append(vals, entries[i].v)
+ }
+ var first, last = keys[0], keys[len(keys)-1]
+ testcase := mrand.Intn(6)
+ var index int
+ switch testcase {
+ case 0:
+ // Modified key
+ index = mrand.Intn(end - start)
+ keys[index] = randBytes(32) // In theory it can't be same
+ case 1:
+ // Modified val
+ index = mrand.Intn(end - start)
+ vals[index] = randBytes(20) // In theory it can't be same
+ case 2:
+ // Gapped entry slice
+ index = mrand.Intn(end - start)
+ if (index == 0 && start < 100) || (index == end-start-1 && end <= 100) {
+ continue
+ }
+ keys = append(keys[:index], keys[index+1:]...)
+ vals = append(vals[:index], vals[index+1:]...)
+ case 3:
+ // Out of order
+ index1 := mrand.Intn(end - start)
+ index2 := mrand.Intn(end - start)
+ if index1 == index2 {
+ continue
+ }
+ keys[index1], keys[index2] = keys[index2], keys[index1]
+ vals[index1], vals[index2] = vals[index2], vals[index1]
+ case 4:
+ // Set random key to nil, do nothing
+ index = mrand.Intn(end - start)
+ keys[index] = nil
+ case 5:
+ // Set random value to nil, deletion
+ index = mrand.Intn(end - start)
+ vals[index] = nil
+ }
+ _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+ if err == nil {
+ t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
+ }
+ }
+}
+
+// TestGappedRangeProof focuses on the small trie with embedded nodes.
+// If the gapped node is embedded in the trie, it should be detected too.
+func TestGappedRangeProof(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	var entries []*kv // Sorted entries
+	for i := byte(0); i < 10; i++ {
+		value := &kv{types.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+		trie.Update(value.k, value.v)
+		entries = append(entries, value)
+	}
+	first, last := 2, 8
+	proof := memorydb.New()
+	if err := trie.Prove(entries[first].k, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[last-1].k, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var keys [][]byte
+	var vals [][]byte
+	for i := first; i < last; i++ {
+		// Deliberately omit the middle entry to create a gap in the range.
+		if i == (first+last)/2 {
+			continue
+		}
+		keys = append(keys, entries[i].k)
+		vals = append(vals, entries[i].v)
+	}
+	_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+	if err == nil {
+		t.Fatal("expect error, got nil")
+	}
+}
+
+// TestSameSideProofs tests the element is not in the range covered by proofs
+func TestSameSideProofs(t *testing.T) {
+ trie, vals := randomTrie(4096)
+ var entries entrySlice
+ for _, kv := range vals {
+ entries = append(entries, kv)
+ }
+ sort.Sort(entries)
+
+ pos := 1000
+ first := decreaseKey(types.CopyBytes(entries[pos].k))
+ first = decreaseKey(first)
+ last := decreaseKey(types.CopyBytes(entries[pos].k))
+
+ proof := memorydb.New()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ if err == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+
+ first = increaseKey(types.CopyBytes(entries[pos].k))
+ last = increaseKey(types.CopyBytes(entries[pos].k))
+ last = increaseKey(last)
+
+ proof = memorydb.New()
+ if err := trie.Prove(first, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the first node %v", err)
+ }
+ if err := trie.Prove(last, 0, proof); err != nil {
+ t.Fatalf("Failed to prove the last node %v", err)
+ }
+ _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+ if err == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+// TestHasRightElement checks the hasMore flag returned by VerifyRangeProof
+// for ranges with existent and non-existent edge proofs. A start/end of -1
+// selects a non-existent all-zero / all-0xff edge key respectively.
+func TestHasRightElement(t *testing.T) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	var entries entrySlice
+	for i := 0; i < 4096; i++ {
+		value := &kv{randBytes(32), randBytes(20), false}
+		trie.Update(value.k, value.v)
+		entries = append(entries, value)
+	}
+	sort.Sort(entries)
+
+	var cases = []struct {
+		start   int
+		end     int
+		hasMore bool
+	}{
+		{-1, 1, true}, // single element with non-existent left proof
+		{0, 1, true}, // single element with existent left proof
+		{0, 10, true},
+		{50, 100, true},
+		{50, len(entries), false}, // No more element expected
+		{len(entries) - 1, len(entries), false}, // Single last element with two existent proofs(point to same key)
+		{len(entries) - 1, -1, false}, // Single last element with non-existent right proof
+		{0, len(entries), false}, // The whole set with existent left proof
+		{-1, len(entries), false}, // The whole set with non-existent left proof
+		{-1, -1, false}, // The whole set with non-existent left/right proof
+	}
+	for _, c := range cases {
+		var (
+			firstKey []byte
+			lastKey  []byte
+			start    = c.start
+			end      = c.end
+			proof    = memorydb.New()
+		)
+		if c.start == -1 {
+			firstKey, start = types.Hash{}.Bytes(), 0
+			if err := trie.Prove(firstKey, 0, proof); err != nil {
+				t.Fatalf("Failed to prove the first node %v", err)
+			}
+		} else {
+			firstKey = entries[c.start].k
+			if err := trie.Prove(entries[c.start].k, 0, proof); err != nil {
+				t.Fatalf("Failed to prove the first node %v", err)
+			}
+		}
+		// Fixed below: both right-edge branches previously reported
+		// "Failed to prove the first node" (copy-paste bug).
+		if c.end == -1 {
+			lastKey, end = types.StringToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").Bytes(), len(entries)
+			if err := trie.Prove(lastKey, 0, proof); err != nil {
+				t.Fatalf("Failed to prove the last node %v", err)
+			}
+		} else {
+			lastKey = entries[c.end-1].k
+			if err := trie.Prove(entries[c.end-1].k, 0, proof); err != nil {
+				t.Fatalf("Failed to prove the last node %v", err)
+			}
+		}
+		k := make([][]byte, 0)
+		v := make([][]byte, 0)
+		for i := start; i < end; i++ {
+			k = append(k, entries[i].k)
+			v = append(v, entries[i].v)
+		}
+		hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
+		if err != nil {
+			t.Fatalf("Expected no error, got %v", err)
+		}
+		if hasMore != c.hasMore {
+			t.Fatalf("Wrong hasMore indicator, want %t, got %t", c.hasMore, hasMore)
+		}
+	}
+}
+
+// TestEmptyRangeProof tests the range proof with "no" element.
+// The first edge proof must be a non-existent proof.
+func TestEmptyRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	var cases = []struct {
+		pos int
+		err bool
+	}{
+		// Edge key past the largest entry: nothing to the right, must verify.
+		{len(entries) - 1, false},
+		// Edge key in the middle: entries remain to the right, must fail.
+		{500, true},
+	}
+	for _, c := range cases {
+		proof := memorydb.New()
+		first := increaseKey(types.CopyBytes(entries[c.pos].k))
+		if err := trie.Prove(first, 0, proof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		_, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
+		if c.err && err == nil {
+			t.Fatalf("Expected error, got nil")
+		}
+		if !c.err && err != nil {
+			t.Fatalf("Expected no error, got %v", err)
+		}
+	}
+}
+
+// TestBloatedProof tests a malicious proof, where the proof is more or less the
+// whole trie. Previously we didn't accept such packets, but the new APIs do, so
+// lets leave this test as a bit weird, but present.
+func TestBloatedProof(t *testing.T) {
+	// Use a small trie
+	trie, kvs := nonRandomTrie(100)
+	var entries entrySlice
+	for _, kv := range kvs {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+	var keys [][]byte
+	var vals [][]byte
+
+	proof := memorydb.New()
+	// In the 'malicious' case, we add proofs for every single item
+	// (but only one key/value pair used as leaf)
+	for i, entry := range entries {
+		trie.Prove(entry.k, 0, proof)
+		if i == 50 {
+			keys = append(keys, entry.k)
+			vals = append(vals, entry.v)
+		}
+	}
+	// For reference, we use the same function, but _only_ prove the first
+	// and last element
+	// NOTE(review): `want` is built for reference only and is never compared
+	// against `proof` below — presumably a leftover from the upstream test.
+	want := memorydb.New()
+	trie.Prove(keys[0], 0, want)
+	trie.Prove(keys[len(keys)-1], 0, want)
+
+	if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
+		t.Fatalf("expected bloated proof to succeed, got %v", err)
+	}
+}
+
+// TestEmptyValueRangeProof tests normal range proof with both edge proofs
+// as the existent proof, but with an extra empty value included, which is a
+// noop technically, but practically should be rejected.
+func TestEmptyValueRangeProof(t *testing.T) {
+	trie, values := randomTrie(512)
+	var entries entrySlice
+	for _, kv := range values {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	// Create a new entry with a slightly modified key: increment the last
+	// non-0xff byte so the new key sorts immediately after entries[mid-1].
+	mid := len(entries) / 2
+	key := types.CopyBytes(entries[mid-1].k)
+	for n := len(key) - 1; n >= 0; n-- {
+		if key[n] < 0xff {
+			key[n]++
+			break
+		}
+	}
+	// The empty value marks the entry as a "noop"; verification must reject it.
+	noop := &kv{key, []byte{}, false}
+	entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
+
+	start, end := 1, len(entries)-1
+
+	proof := memorydb.New()
+	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var keys [][]byte
+	var vals [][]byte
+	for i := start; i < end; i++ {
+		keys = append(keys, entries[i].k)
+		vals = append(vals, entries[i].v)
+	}
+	_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+	if err == nil {
+		t.Fatalf("Expected failure on noop entry")
+	}
+}
+
+// TestAllElementsEmptyValueRangeProof tests the range proof with all elements,
+// but with an extra empty value included, which is a noop technically, but
+// practically should be rejected.
+func TestAllElementsEmptyValueRangeProof(t *testing.T) {
+	trie, values := randomTrie(512)
+	var entries entrySlice
+	for _, kv := range values {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	// Create a new entry with a slightly modified key: increment the last
+	// non-0xff byte so the new key sorts immediately after entries[mid-1].
+	mid := len(entries) / 2
+	key := types.CopyBytes(entries[mid-1].k)
+	for n := len(key) - 1; n >= 0; n-- {
+		if key[n] < 0xff {
+			key[n]++
+			break
+		}
+	}
+	noop := &kv{key, []byte{}, false}
+	entries = append(append(append([]*kv{}, entries[:mid]...), noop), entries[mid:]...)
+
+	var keys [][]byte
+	var vals [][]byte
+	for i := 0; i < len(entries); i++ {
+		keys = append(keys, entries[i].k)
+		vals = append(vals, entries[i].v)
+	}
+	// nil edge keys/proof = "whole trie" mode; the noop entry must still fail.
+	_, err := VerifyRangeProof(trie.Hash(), nil, nil, keys, vals, nil)
+	if err == nil {
+		t.Fatalf("Expected failure on noop entry")
+	}
+}
+
+// mutateByte overwrites one randomly chosen byte of b with a different
+// random value. The replacement is guaranteed to differ from the original,
+// so the slice is always observably changed. Panics if b is empty (Intn(0)).
+func mutateByte(b []byte) {
+	r := mrand.Intn(len(b))
+	for {
+		// Intn(256) covers the full 0x00..0xff byte range (the previous
+		// Intn(255) could never produce 0xff).
+		newB := byte(mrand.Intn(256))
+		if newB != b[r] {
+			b[r] = newB
+
+			break
+		}
+	}
+}
+
+// increaseKey treats key as a big-endian unsigned integer and increments it
+// in place, carrying into higher-order bytes. The (possibly wrapped-around)
+// slice is returned for convenience.
+func increaseKey(key []byte) []byte {
+	i := len(key) - 1
+	for i >= 0 {
+		key[i]++
+		if key[i] != 0x0 {
+			break
+		}
+		i--
+	}
+	return key
+}
+
+// decreaseKey treats key as a big-endian unsigned integer and decrements it
+// in place, borrowing from higher-order bytes. The (possibly wrapped-around)
+// slice is returned for convenience.
+func decreaseKey(key []byte) []byte {
+	i := len(key) - 1
+	for i >= 0 {
+		key[i]--
+		if key[i] != 0xff {
+			break
+		}
+		i--
+	}
+	return key
+}
+
+// BenchmarkProve measures the cost of producing a single-key Merkle proof
+// against a small random trie, cycling through the known keys.
+func BenchmarkProve(b *testing.B) {
+	trie, vals := randomTrie(100)
+	var keys []string
+	for k := range vals {
+		keys = append(keys, k)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		kv := vals[keys[i%len(keys)]]
+		proofs := memorydb.New()
+		// Prove error is ignored; an empty proof db is the failure signal here.
+		if trie.Prove(kv.k, 0, proofs); proofs.Len() == 0 {
+			b.Fatalf("zero length proof for %x", kv.k)
+		}
+	}
+}
+
+// BenchmarkVerifyProof measures single-key proof verification, reusing a set
+// of proofs generated once before the timer starts.
+func BenchmarkVerifyProof(b *testing.B) {
+	trie, vals := randomTrie(100)
+	root := trie.Hash()
+	var keys []string
+	var proofs []*memorydb.Database
+	for k := range vals {
+		keys = append(keys, k)
+		proof := memorydb.New()
+		trie.Prove([]byte(k), 0, proof) // error ignored; VerifyProof below will catch a bad proof
+		proofs = append(proofs, proof)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		im := i % len(keys)
+		if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+			b.Fatalf("key %x: %v", keys[im], err)
+		}
+	}
+}
+
+// Range-proof verification benchmarks at increasing range sizes.
+func BenchmarkVerifyRangeProof10(b *testing.B)   { benchmarkVerifyRangeProof(b, 10) }
+func BenchmarkVerifyRangeProof100(b *testing.B)  { benchmarkVerifyRangeProof(b, 100) }
+func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, 1000) }
+func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) }
+
+// benchmarkVerifyRangeProof verifies a proof covering `size` consecutive
+// entries (starting at index 2) of an 8192-entry random trie. Proof
+// construction happens once, outside the timed loop.
+func benchmarkVerifyRangeProof(b *testing.B, size int) {
+	b.Helper()
+
+	trie, vals := randomTrie(8192)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	start := 2
+	end := start + size
+	proof := memorydb.New()
+	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
+		b.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[end-1].k, 0, proof); err != nil {
+		b.Fatalf("Failed to prove the last node %v", err)
+	}
+	var keys [][]byte
+	var values [][]byte
+	for i := start; i < end; i++ {
+		keys = append(keys, entries[i].k)
+		values = append(values, entries[i].v)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
+		if err != nil {
+			b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+}
+
+// Proof-less (whole-range) verification benchmarks at increasing sizes.
+// The "...10" variant now actually runs with size 10 — it previously ran
+// with 100, contradicting its name.
+func BenchmarkVerifyRangeNoProof10(b *testing.B)   { benchmarkVerifyRangeNoProof(b, 10) }
+func BenchmarkVerifyRangeNoProof500(b *testing.B)  { benchmarkVerifyRangeNoProof(b, 500) }
+func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
+
+// benchmarkVerifyRangeNoProof verifies ALL entries of a `size`-entry trie
+// without any edge proofs (the nil proof argument) — i.e. the verifier must
+// rebuild the whole trie from keys/values and compare roots.
+func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
+	b.Helper()
+
+	trie, vals := randomTrie(size)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	var keys [][]byte
+	var values [][]byte
+	for _, entry := range entries {
+		keys = append(keys, entry.k)
+		values = append(values, entry.v)
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
+		if err != nil {
+			b.Fatalf("Expected no error, got %v", err)
+		}
+	}
+}
+
+// randomTrie builds an in-memory trie containing 200 deterministic entries
+// (two per i in 0..99, with overlapping i/i+10 keys) plus n random 32-byte
+// keys, and returns the trie together with a key->entry map.
+func randomTrie(n int) (*Trie, map[string]*kv) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	vals := make(map[string]*kv)
+	// Deterministic base entries; note i+10 overlaps with later iterations,
+	// so those keys are overwritten in both the trie and the map.
+	for i := byte(0); i < 100; i++ {
+		value := &kv{types.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+		value2 := &kv{types.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
+		trie.Update(value.k, value.v)
+		trie.Update(value2.k, value2.v)
+		vals[string(value.k)] = value
+		vals[string(value2.k)] = value2
+	}
+	// Random entries to give the trie realistic shape.
+	for i := 0; i < n; i++ {
+		value := &kv{randBytes(32), randBytes(20), false}
+		trie.Update(value.k, value.v)
+		vals[string(value.k)] = value
+	}
+	return trie, vals
+}
+
+// randBytes returns n cryptographically random bytes, panicking if the
+// system randomness source fails (previously the error was silently
+// ignored, which could yield all-zero "random" keys).
+func randBytes(n int) []byte {
+	r := make([]byte, n)
+	if _, err := crand.Read(r); err != nil {
+		panic(err)
+	}
+	return r
+}
+
+// nonRandomTrie builds a deterministic trie of n entries whose keys are the
+// little-endian encodings of 0..n-1, and returns it with a key->entry map.
+func nonRandomTrie(n int) (*Trie, map[string]*kv) {
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	vals := make(map[string]*kv)
+	max := uint64(0xffffffffffffffff)
+	for i := uint64(0); i < uint64(n); i++ {
+		value := make([]byte, 32)
+		key := make([]byte, 32)
+		binary.LittleEndian.PutUint64(key, i)
+		// i-max deliberately wraps around (uint64 underflow), producing
+		// a value distinct from the key for every i.
+		binary.LittleEndian.PutUint64(value, i-max)
+		elem := &kv{key, value, false}
+		trie.Update(elem.k, elem.v)
+		vals[string(elem.k)] = elem
+	}
+	return trie, vals
+}
+
+// TestRangeProofKeysWithSharedPrefix verifies a range proof over a two-entry
+// trie whose keys share a long common prefix, using edge proofs for keys that
+// are strictly outside the stored range (all-zero and all-0xff bounds).
+func TestRangeProofKeysWithSharedPrefix(t *testing.T) {
+	keys := [][]byte{
+		types.StringToBytes("aa10000000000000000000000000000000000000000000000000000000000000"),
+		types.StringToBytes("aa20000000000000000000000000000000000000000000000000000000000000"),
+	}
+	vals := [][]byte{
+		types.StringToBytes("02"),
+		types.StringToBytes("03"),
+	}
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	for i, key := range keys {
+		trie.Update(key, vals[i])
+	}
+	root := trie.Hash()
+	proof := memorydb.New()
+	// Edge proofs are taken at the extreme bounds rather than the stored keys.
+	start := types.StringToBytes("0000000000000000000000000000000000000000000000000000000000000000")
+	end := types.StringToBytes("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+	if err := trie.Prove(start, 0, proof); err != nil {
+		t.Fatalf("failed to prove start: %v", err)
+	}
+	if err := trie.Prove(end, 0, proof); err != nil {
+		t.Fatalf("failed to prove end: %v", err)
+	}
+
+	more, err := VerifyRangeProof(root, start, end, keys, vals, proof)
+	if err != nil {
+		t.Fatalf("failed to verify range proof: %v", err)
+	}
+	// The full range was covered, so no further elements may exist.
+	if more != false {
+		t.Error("expected more to be false")
+	}
+}
diff --git a/trie/schema.go b/trie/schema.go
new file mode 100644
index 0000000000..9c1c207683
--- /dev/null
+++ b/trie/schema.go
@@ -0,0 +1,97 @@
+// Copyright 2021 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+const (
+	// HashScheme is the identifier of the hash-based node scheme, where trie
+	// nodes are keyed in the database purely by their hash.
+	HashScheme = "hashScheme" // Identifier of hash based node scheme
+
+	// Path-based scheme will be introduced in the following PRs.
+	// PathScheme = "pathScheme" // Identifier of path based node scheme
+)
+
+// NodeScheme describes the scheme for interacting nodes in disk.
+//
+// The owner/path arguments exist so a future path-based scheme can key nodes
+// by their position; the hash scheme implementation ignores them.
+type NodeScheme interface {
+	// Name returns the identifier of node scheme.
+	Name() string
+
+	// HasTrieNode checks the trie node presence with the provided node info and
+	// the associated node hash.
+	HasTrieNode(db kvdb.KVReader, owner types.Hash, path []byte, hash types.Hash) bool
+
+	// ReadTrieNode retrieves the trie node from database with the provided node
+	// info and the associated node hash.
+	ReadTrieNode(db kvdb.KVReader, owner types.Hash, path []byte, hash types.Hash) []byte
+
+	// WriteTrieNode writes the trie node into database with the provided node
+	// info and associated node hash.
+	WriteTrieNode(db kvdb.KVWriter, owner types.Hash, path []byte, hash types.Hash, node []byte)
+
+	// DeleteTrieNode deletes the trie node from database with the provided node
+	// info and associated node hash.
+	DeleteTrieNode(db kvdb.KVWriter, owner types.Hash, path []byte, hash types.Hash)
+
+	// IsTrieNode returns an indicator if the given database key is the key of
+	// trie node according to the scheme.
+	IsTrieNode(key []byte) (bool, []byte)
+}
+
+// hashScheme implements NodeScheme, keying trie nodes purely by hash.
+type hashScheme struct{}
+
+// Name returns the identifier of hash based scheme.
+func (scheme *hashScheme) Name() string {
+	return HashScheme
+}
+
+// HasTrieNode checks the trie node presence with the provided node info and
+// the associated node hash. The owner and path arguments are ignored: under
+// the hash scheme nodes are keyed by hash alone.
+func (scheme *hashScheme) HasTrieNode(db kvdb.KVReader, owner types.Hash, path []byte, hash types.Hash) bool {
+	return rawdb.HasTrieNode(db, hash)
+}
+
+// ReadTrieNode retrieves the trie node from database with the provided node info
+// and associated node hash. The owner and path arguments are ignored: under
+// the hash scheme nodes are keyed by hash alone.
+func (scheme *hashScheme) ReadTrieNode(db kvdb.KVReader, owner types.Hash, path []byte, hash types.Hash) []byte {
+	return rawdb.ReadTrieNode(db, hash)
+}
+
+// WriteTrieNode writes the trie node into database with the provided node info
+// and associated node hash. The owner and path arguments are ignored: under
+// the hash scheme nodes are keyed by hash alone.
+func (scheme *hashScheme) WriteTrieNode(db kvdb.KVWriter, owner types.Hash, path []byte, hash types.Hash, node []byte) {
+	rawdb.WriteTrieNode(db, hash, node)
+}
+
+// DeleteTrieNode deletes the trie node from database with the provided node info
+// and associated node hash. The owner and path arguments are ignored: under
+// the hash scheme nodes are keyed by hash alone.
+func (scheme *hashScheme) DeleteTrieNode(db kvdb.KVWriter, owner types.Hash, path []byte, hash types.Hash) {
+	rawdb.DeleteTrieNode(db, hash)
+}
+
+// IsTrieNode reports whether the given database key is a trie-node key under
+// the hash scheme, i.e. whether it is exactly one hash long. On success the
+// key itself (the node hash) is returned alongside true.
+func (scheme *hashScheme) IsTrieNode(key []byte) (bool, []byte) {
+	if len(key) != types.HashLength {
+		return false, nil
+	}
+
+	return true, key
+}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
new file mode 100644
index 0000000000..833a980267
--- /dev/null
+++ b/trie/secure_trie.go
@@ -0,0 +1,289 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/dogechain-lab/fastrlp"
+)
+
+// SecureTrie is the old name of StateTrie.
+// Deprecated: use StateTrie.
+type SecureTrie = StateTrie
+
+// NewSecure creates a new StateTrie by assembling a trie ID from the given
+// state root, owner and root hashes.
+// Deprecated: use NewStateTrie.
+func NewSecure(
+	stateRoot types.Hash,
+	owner types.Hash,
+	root types.Hash,
+	db *Database,
+	logger Logger,
+) (*SecureTrie, error) {
+	id := &ID{
+		StateRoot: stateRoot,
+		Owner:     owner,
+		Root:      root,
+	}
+
+	return NewStateTrie(id, db, logger)
+}
+
+// StateTrie wraps a trie with key hashing. In a stateTrie trie, all
+// access operations hash the key using keccak256. This prevents
+// calling code from creating long chains of nodes that
+// increase the access time.
+//
+// Contrary to a regular trie, a StateTrie can only be created with
+// New and must have an attached database. The database also stores
+// the preimage of each key if preimage recording is enabled.
+//
+// StateTrie is not safe for concurrent use.
+type StateTrie struct {
+	trie             Trie
+	hashKeyBuf       [types.HashLength]byte // scratch buffer reused by hashKey; one reason the type is not concurrency-safe
+	secKeyCache      map[string][]byte      // hashed key -> original key (preimages)
+	secKeyCacheOwner *StateTrie             // Pointer to self, replace the key cache on mismatch
+	logger           Logger                 // used to report otherwise-unhandled trie errors
+}
+
+// NewStateTrie creates a trie with an existing root node from a backing database.
+//
+// If root is the zero hash or the sha3 hash of an empty string, the
+// trie is initially empty. Otherwise, New will panic if db is nil
+// and returns MissingNodeError if the root node cannot be found.
+func NewStateTrie(id *ID, db *Database, logger Logger) (*StateTrie, error) {
+	if db == nil {
+		panic("trie.NewStateTrie called without a database")
+	}
+
+	trie, err := New(id, db, logger)
+	if err != nil {
+		return nil, err
+	}
+
+	// Store the logger: without it, the error paths in Get/Update/Delete
+	// would invoke a nil interface and panic instead of logging.
+	return &StateTrie{trie: *trie, logger: logger}, nil
+}
+
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// Unlike TryGet, errors are logged rather than returned.
+func (t *StateTrie) Get(key []byte) []byte {
+	res, err := t.TryGet(key)
+	if err != nil {
+		t.logger.Error("Unhandled trie error in StateTrie.Get", "err", err)
+	}
+
+	return res
+}
+
+// TryGet returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// If the specified node is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+// The key is hashed before lookup (see hashKey).
+func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
+	return t.trie.TryGet(t.hashKey(key))
+}
+
+// TryGetAccount attempts to retrieve an account with provided trie path.
+// If the specified account is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) TryGetAccount(key []byte) (*stypes.Account, error) {
+	res, err := t.trie.TryGet(t.hashKey(key))
+	if res == nil || err != nil {
+		return nil, err
+	}
+
+	ret := new(stypes.Account)
+	if err := ret.UnmarshalRlp(res); err != nil {
+		// Do not hand back a partially-populated account on decode failure.
+		return nil, err
+	}
+
+	return ret, nil
+}
+
+// TryGetAccountWithPreHashedKey does the same thing as TryGetAccount, however
+// it expects a key that is already hashed. This constitutes an abstraction leak,
+// since the client code needs to know the key format.
+func (t *StateTrie) TryGetAccountWithPreHashedKey(key []byte) (*stypes.Account, error) {
+	res, err := t.trie.TryGet(key)
+	if res == nil || err != nil {
+		return nil, err
+	}
+
+	ret := new(stypes.Account)
+	if err := ret.UnmarshalRlp(res); err != nil {
+		// Do not hand back a partially-populated account on decode failure.
+		return nil, err
+	}
+
+	return ret, nil
+}
+
+// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
+// possible to use keybyte-encoding as the path might contain odd nibbles.
+// If the specified trie node is not in the trie, nil will be returned.
+// If a trie node is not found in the database, a MissingNodeError is returned.
+// The path is used verbatim — no key hashing is applied.
+func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
+	return t.trie.TryGetNode(path)
+}
+
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+// Unlike TryUpdate, errors are logged rather than returned.
+func (t *StateTrie) Update(key, value []byte) {
+	if err := t.TryUpdate(key, value); err != nil {
+		t.logger.Error("Unhandled trie error in StateTrie.Update", "err", err)
+	}
+}
+
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node is not found in the database, a MissingNodeError is returned.
+func (t *StateTrie) TryUpdate(key, value []byte) error {
+	hk := t.hashKey(key)
+
+	err := t.trie.TryUpdate(hk, value)
+	if err != nil {
+		return err
+	}
+
+	// Record the preimage so GetKey can map the hashed key back to the original.
+	t.getSecKeyCache()[string(hk)] = types.CopyBytes(key)
+
+	return nil
+}
+
+// accountArenaPool amortizes fastrlp arena allocations across account writes.
+var accountArenaPool fastrlp.ArenaPool
+
+// TryUpdateAccount account will abstract the write of an account to the
+// secure trie: the account is RLP-encoded and stored under the hashed key.
+func (t *StateTrie) TryUpdateAccount(key []byte, acc *stypes.Account) error {
+	ar := accountArenaPool.Get()
+	defer accountArenaPool.Put(ar)
+
+	vv := acc.MarshalWith(ar)
+	data := vv.MarshalTo(nil)
+
+	hk := t.hashKey(key)
+
+	if err := t.trie.TryUpdate(hk, data); err != nil {
+		return err
+	}
+
+	// Record the preimage so GetKey can map the hashed key back to the original.
+	t.getSecKeyCache()[string(hk)] = types.CopyBytes(key)
+
+	return nil
+}
+
+// Delete removes any existing value for key from the trie.
+// Unlike TryDelete, errors are logged rather than returned.
+func (t *StateTrie) Delete(key []byte) {
+	if err := t.TryDelete(key); err != nil {
+		t.logger.Error("Unhandled trie error in StateTrie.Delete", "err", err)
+	}
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If the specified trie node is not in the trie, nothing will be changed.
+// If a node is not found in the database, a MissingNodeError is returned.
+// The cached preimage for the hashed key is evicted as well.
+func (t *StateTrie) TryDelete(key []byte) error {
+	hk := t.hashKey(key)
+	delete(t.getSecKeyCache(), string(hk))
+
+	return t.trie.TryDelete(hk)
+}
+
+// TryDeleteAccount abstracts an account deletion from the trie.
+// It is currently identical to TryDelete; it exists so account callers have
+// a symmetric counterpart to TryUpdateAccount.
+func (t *StateTrie) TryDeleteAccount(key []byte) error {
+	hk := t.hashKey(key)
+	delete(t.getSecKeyCache(), string(hk))
+
+	return t.trie.TryDelete(hk)
+}
+
+// GetKey returns the sha3 preimage of a hashed key that was previously used
+// to store a value, or nil if the preimage is not in the cache.
+func (t *StateTrie) GetKey(shaKey []byte) []byte {
+	key, ok := t.getSecKeyCache()[string(shaKey)]
+	if !ok {
+		return nil
+	}
+
+	return key
+}
+
+// Commit collects all dirty nodes in the trie and replaces them with the
+// corresponding node hash. All collected nodes (including dirty leaves if
+// collectLeaf is true) will be encapsulated into a nodeset for return.
+// The returned nodeset can be nil if the trie is clean (nothing to commit).
+// Once the trie is committed, it's not usable anymore. A new trie must
+// be created with new root and updated trie database for following usage
+func (t *StateTrie) Commit(collectLeaf bool) (types.Hash, *NodeSet, error) {
+	// NOTE(review): the preimage cache is only DISCARDED here — despite the
+	// old comment, nothing is written to disk. Confirm whether preimage
+	// persistence was intentionally dropped.
+	if len(t.getSecKeyCache()) > 0 {
+		t.secKeyCache = make(map[string][]byte)
+	}
+	// Commit the trie and return its modified nodeset.
+	return t.trie.Commit(collectLeaf)
+}
+
+// Hash returns the root hash of StateTrie. It does not write to the
+// database and can be used even if the trie doesn't have one.
+func (t *StateTrie) Hash() types.Hash {
+	return t.trie.Hash()
+}
+
+// Copy returns a copy of StateTrie. The preimage cache is shared with the
+// original until one side writes (see getSecKeyCache ownership handling).
+func (t *StateTrie) Copy() *StateTrie {
+	return &StateTrie{
+		trie:        *t.trie.Copy(),
+		secKeyCache: t.secKeyCache,
+		// Propagate the logger; otherwise Get/Update/Delete on the copy
+		// would call a nil interface and panic on the first trie error.
+		logger: t.logger,
+	}
+}
+
+// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
+// starts at the key after the given start key.
+func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
+	return t.trie.NodeIterator(start)
+}
+
+// hashKey returns the hash of key as an ephemeral buffer.
+// The caller must not hold onto the return value because it will become
+// invalid on the next call to hashKey or secKey.
+func (t *StateTrie) hashKey(key []byte) []byte {
+	// Borrow a pooled hasher; Write/Read on the sha state never fail, so
+	// their error returns are deliberately ignored.
+	h := newHasher(false)
+	h.sha.Reset()
+	h.sha.Write(key)
+	h.sha.Read(t.hashKeyBuf[:]) // fill the reused per-trie scratch buffer
+	returnHasherToPool(h)
+
+	return t.hashKeyBuf[:]
+}
+
+// getSecKeyCache returns the current secure key cache, creating a new one if
+// ownership changed (i.e. the current secure trie is a copy of another owning
+// the actual cache). This lazily un-shares the cache that Copy leaves shared.
+func (t *StateTrie) getSecKeyCache() map[string][]byte {
+	if t != t.secKeyCacheOwner {
+		t.secKeyCacheOwner = t
+		t.secKeyCache = make(map[string][]byte)
+	}
+
+	return t.secKeyCache
+}
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
new file mode 100644
index 0000000000..8d38369704
--- /dev/null
+++ b/trie/secure_trie_test.go
@@ -0,0 +1,153 @@
+// Copyright 2015 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bytes"
+ "fmt"
+ "runtime"
+ "sync"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// newEmptySecure returns an empty StateTrie backed by a fresh in-memory
+// database; the construction error is ignored (empty root cannot fail).
+func newEmptySecure() *StateTrie {
+	logger := hclog.NewNullLogger()
+	trie, _ := NewStateTrie(TrieID(types.Hash{}), NewDatabase(rawdb.NewMemoryDatabase(), logger), logger)
+	return trie
+}
+
+// makeTestStateTrie creates a large enough secure trie for testing: 255
+// iterations, each inserting 12 keys, committed once and reopened at the
+// resulting root. Returns the backing db, the reopened trie and the content.
+func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
+	// Create an empty trie
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	trie, _ := NewStateTrie(TrieID(types.Hash{}), triedb, logger)
+
+	// Fill it with some arbitrary data
+	content := make(map[string][]byte)
+	for i := byte(0); i < 255; i++ {
+		// Map the same data under multiple keys
+		key, val := types.LeftPadBytes([]byte{1, i}, 32), []byte{i}
+		content[string(key)] = val
+		trie.Update(key, val)
+
+		key, val = types.LeftPadBytes([]byte{2, i}, 32), []byte{i}
+		content[string(key)] = val
+		trie.Update(key, val)
+
+		// Add some other data to inflate the trie
+		for j := byte(3); j < 13; j++ {
+			key, val = types.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
+			content[string(key)] = val
+			trie.Update(key, val)
+		}
+	}
+	root, nodes, err := trie.Commit(false)
+	if err != nil {
+		panic(fmt.Errorf("failed to commit trie %v", err))
+	}
+	if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
+		panic(fmt.Errorf("failed to commit db %v", err))
+	}
+	// Re-create the trie based on the new state
+	trie, _ = NewStateTrie(TrieID(root), triedb, logger)
+	return triedb, trie, content
+}
+
+// TestSecureDelete applies a fixed update/delete script (empty value means
+// delete) and pins the resulting root hash against a known-good constant.
+func TestSecureDelete(t *testing.T) {
+	trie := newEmptySecure()
+	vals := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"ether", ""},
+		{"dog", "puppy"},
+		{"shaman", ""},
+	}
+	for _, val := range vals {
+		if val.v != "" {
+			trie.Update([]byte(val.k), []byte(val.v))
+		} else {
+			trie.Delete([]byte(val.k))
+		}
+	}
+	hash := trie.Hash()
+	exp := types.StringToHash("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
+	if hash != exp {
+		t.Errorf("expected %x got %x", exp, hash)
+	}
+}
+
+// TestSecureGetKey checks that values are retrievable by their plain key and
+// that GetKey recovers the preimage from the keccak256-hashed key.
+func TestSecureGetKey(t *testing.T) {
+	trie := newEmptySecure()
+	trie.Update([]byte("foo"), []byte("bar"))
+
+	key := []byte("foo")
+	value := []byte("bar")
+	seckey := crypto.Keccak256(key)
+
+	if !bytes.Equal(trie.Get(key), value) {
+		t.Errorf("Get did not return bar")
+	}
+	if k := trie.GetKey(seckey); !bytes.Equal(k, key) {
+		t.Errorf("GetKey returned %q, want %q", k, key)
+	}
+}
+
+// TestStateTrieConcurrency exercises independent copies of one trie from
+// NumCPU goroutines; run with -race to detect shared-state violations.
+func TestStateTrieConcurrency(t *testing.T) {
+	// Create an initial trie and copy if for concurrent access
+	_, trie, _ := makeTestStateTrie()
+
+	threads := runtime.NumCPU()
+	tries := make([]*StateTrie, threads)
+	for i := 0; i < threads; i++ {
+		tries[i] = trie.Copy()
+	}
+	// Start a batch of goroutines interacting with the trie
+	pend := new(sync.WaitGroup)
+	pend.Add(threads)
+	for i := 0; i < threads; i++ {
+		go func(index int) {
+			defer pend.Done()
+
+			for j := byte(0); j < 255; j++ {
+				// Map the same data under multiple keys
+				key, val := types.LeftPadBytes([]byte{byte(index), 1, j}, 32), []byte{j}
+				tries[index].Update(key, val)
+
+				key, val = types.LeftPadBytes([]byte{byte(index), 2, j}, 32), []byte{j}
+				tries[index].Update(key, val)
+
+				// Add some other data to inflate the trie
+				for k := byte(3); k < 13; k++ {
+					key, val = types.LeftPadBytes([]byte{byte(index), k, j}, 32), []byte{k, j}
+					tries[index].Update(key, val)
+				}
+			}
+			// Commit results are deliberately discarded; only race-freedom matters.
+			tries[index].Commit(false)
+		}(i)
+	}
+	// Wait for all threads to finish
+	pend.Wait()
+}
diff --git a/trie/stacktrie.go b/trie/stacktrie.go
new file mode 100644
index 0000000000..bbc8ba96ae
--- /dev/null
+++ b/trie/stacktrie.go
@@ -0,0 +1,563 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see .
+
+package trie
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/gob"
+ "errors"
+ "io"
+ "log"
+ "sync"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// ErrCommitDisabled is returned when committing a stacktrie that has no
+// backing database (nil write function).
+var ErrCommitDisabled = errors.New("no database for committing")
+
+// stPool recycles StackTrie nodes to reduce allocation churn during inserts.
+var stPool = sync.Pool{
+	New: func() interface{} {
+		return NewStackTrie(nil)
+	},
+}
+
+// NodeWriteFunc is used to provide all information of a dirty node for committing
+// so that callers can flush nodes into database with desired scheme.
+type NodeWriteFunc = func(owner types.Hash, path []byte, hash types.Hash, blob []byte)
+
+// stackTrieFromPool fetches a recycled StackTrie and re-initializes its
+// owner and write function (the pool's New already resets the rest).
+func stackTrieFromPool(writeFn NodeWriteFunc, owner types.Hash) *StackTrie {
+	st, _ := stPool.Get().(*StackTrie)
+	st.owner = owner
+	st.writeFn = writeFn
+
+	return st
+}
+
+// returnToPool resets st to its empty state and hands it back to stPool.
+func returnToPool(st *StackTrie) {
+	st.Reset()
+	stPool.Put(st)
+}
+
+// StackTrie is a trie implementation that expects keys to be inserted
+// in order. Once it determines that a subtree will no longer be inserted
+// into, it will hash it and free up the memory it uses.
+type StackTrie struct {
+	owner    types.Hash     // the owner of the trie
+	nodeType uint8          // node type (as in branch, ext, leaf)
+	val      []byte         // value contained by this node if it's a leaf
+	key      []byte         // key chunk covered by this (leaf|ext) node
+	children [16]*StackTrie // list of children (for branch and exts)
+	writeFn  NodeWriteFunc  // function for committing nodes, can be nil
+}
+
+// NewStackTrie allocates and initializes an empty trie. A nil writeFn
+// disables committing (see ErrCommitDisabled).
+func NewStackTrie(writeFn NodeWriteFunc) *StackTrie {
+	return &StackTrie{
+		nodeType: emptyNode,
+		writeFn:  writeFn,
+	}
+}
+
+// NewStackTrieWithOwner allocates and initializes an empty trie, but with
+// the additional owner field.
+func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner types.Hash) *StackTrie {
+	return &StackTrie{
+		owner:    owner,
+		nodeType: emptyNode,
+		writeFn:  writeFn,
+	}
+}
+
+// NewFromBinary initialises a serialized stacktrie with the given db.
+// The write function is not part of the serialized form, so it is re-attached
+// recursively to every child after decoding.
+func NewFromBinary(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) {
+	var st StackTrie
+	if err := st.UnmarshalBinary(data); err != nil {
+		return nil, err
+	}
+
+	// If a database is used, we need to recursively add it to every child
+	if writeFn != nil {
+		st.setWriter(writeFn)
+	}
+
+	return &st, nil
+}
+
+// MarshalBinary implements encoding.BinaryMarshaler.
+//
+// Layout: a gob-encoded header (owner, node type, value, key) followed, for
+// each of the 16 child slots, by a presence byte (0/1) and, when present, the
+// recursively encoded child. The write function is not serialized.
+//
+// The previous implementation funneled everything through a bufio.Writer
+// whose WriteByte/Write/Flush errors were silently dropped; writing straight
+// into the bytes.Buffer produces the identical byte stream and buffer writes
+// cannot fail.
+func (st *StackTrie) MarshalBinary() (data []byte, err error) {
+	var b bytes.Buffer
+
+	if err := gob.NewEncoder(&b).Encode(struct {
+		Owner    types.Hash
+		NodeType uint8
+		Val      []byte
+		Key      []byte
+	}{
+		st.owner,
+		st.nodeType,
+		st.val,
+		st.key,
+	}); err != nil {
+		return nil, err
+	}
+
+	for _, child := range st.children {
+		if child == nil {
+			b.WriteByte(0)
+
+			continue
+		}
+
+		b.WriteByte(1)
+
+		childData, err := child.MarshalBinary()
+		if err != nil {
+			return nil, err
+		}
+
+		b.Write(childData)
+	}
+
+	return b.Bytes(), nil
+}
+
+// UnmarshalBinary implements encoding.BinaryUnmarshaler.
+// It decodes the format produced by MarshalBinary; writeFn is left nil and
+// must be re-attached by the caller (see NewFromBinary).
+func (st *StackTrie) UnmarshalBinary(data []byte) error {
+	r := bytes.NewReader(data)
+
+	return st.unmarshalBinary(r)
+}
+
+// unmarshalBinary decodes one node (gob header, then 16 presence-prefixed
+// children) from r, recursing for each present child.
+func (st *StackTrie) unmarshalBinary(r io.Reader) error {
+	var dec struct {
+		Owner    types.Hash
+		NodeType uint8
+		Val      []byte
+		Key      []byte
+	}
+
+	// Propagate decode failures instead of silently continuing with a
+	// zero-valued header (the error was previously discarded).
+	if err := gob.NewDecoder(r).Decode(&dec); err != nil {
+		return err
+	}
+
+	st.owner = dec.Owner
+	st.nodeType = dec.NodeType
+	st.val = dec.Val
+	st.key = dec.Key
+
+	var hasChild = make([]byte, 1)
+
+	for i := range st.children {
+		if _, err := r.Read(hasChild); err != nil {
+			return err
+		} else if hasChild[0] == 0 {
+			continue
+		}
+
+		var child StackTrie
+		if err := child.unmarshalBinary(r); err != nil {
+			// Surface truncated/corrupt child data instead of ignoring it.
+			return err
+		}
+
+		st.children[i] = &child
+	}
+
+	return nil
+}
+
+// setWriter installs writeFn on this node and, recursively, on all children.
+func (st *StackTrie) setWriter(writeFn NodeWriteFunc) {
+	st.writeFn = writeFn
+	for _, child := range st.children {
+		if child != nil {
+			child.setWriter(writeFn)
+		}
+	}
+}
+
+// newLeaf builds a leaf node from the pool. The key is appended into the
+// recycled node's (reset) key buffer to reuse its capacity; the value is
+// referenced, not copied.
+func newLeaf(owner types.Hash, key, val []byte, writeFn NodeWriteFunc) *StackTrie {
+	st := stackTrieFromPool(writeFn, owner)
+	st.nodeType = leafNode
+	st.key = append(st.key, key...)
+	st.val = val
+
+	return st
+}
+
+// newExt builds an extension node from the pool covering the given key chunk,
+// with its single child stored in slot 0.
+func newExt(owner types.Hash, key []byte, child *StackTrie, writeFn NodeWriteFunc) *StackTrie {
+	st := stackTrieFromPool(writeFn, owner)
+	st.nodeType = extNode
+	st.key = append(st.key, key...)
+	st.children[0] = child
+
+	return st
+}
+
+// List all values that StackTrie#nodeType can hold.
+// hashedNode marks a subtree that has already been hashed and released.
+const (
+	emptyNode = iota
+	branchNode
+	extNode
+	leafNode
+	hashedNode
+)
+
+// TryUpdate inserts a (key, value) pair into the stack trie.
+// Deletion (zero-length value) is not supported and panics; per the type's
+// contract, keys must arrive in order.
+func (st *StackTrie) TryUpdate(key, value []byte) error {
+	k := keybytesToHex(key)
+
+	if len(value) == 0 {
+		panic("deletion not supported")
+	}
+
+	// Strip the hex terminator nibble; insert works on raw nibble paths.
+	st.insert(k[:len(k)-1], value, nil)
+
+	return nil
+}
+
+// Update is the log-and-continue wrapper around TryUpdate.
+func (st *StackTrie) Update(key, value []byte) {
+	if err := st.TryUpdate(key, value); err != nil {
+		log.Printf("Unhandled trie error in StackTrie.Update: %v\n", err.Error())
+	}
+}
+
+// Reset clears the node back to an empty state so it can be pooled and
+// reused; the key buffer keeps its capacity (truncated, not released).
+func (st *StackTrie) Reset() {
+	st.owner = types.Hash{}
+	st.writeFn = nil
+	st.key = st.key[:0]
+	st.val = nil
+
+	for i := range st.children {
+		st.children[i] = nil
+	}
+
+	st.nodeType = emptyNode
+}
+
+// getDiffIndex returns the first index at which st.key differs from the
+// corresponding prefix of the full key, or len(st.key) when st.key is a
+// complete prefix of it.
+func (st *StackTrie) getDiffIndex(key []byte) int {
+	idx := 0
+	for idx < len(st.key) && st.key[idx] == key[idx] {
+		idx++
+	}
+
+	return idx
+}
+
+// insert adds a (key, value) pair below this node. key is the remaining hex
+// nibbles (terminator stripped); prefix is the nibble path from the root to
+// this node, used to hash finished sibling subtrees in place.
+func (st *StackTrie) insert(key, value []byte, prefix []byte) {
+	switch st.nodeType {
+	case branchNode: /* Branch */
+		idx := int(key[0])
+
+		// Unresolve elder siblings: keys arrive in increasing order, so any
+		// child left of idx is finished and can be hashed immediately.
+		for i := idx - 1; i >= 0; i-- {
+			if st.children[i] != nil {
+				if st.children[i].nodeType != hashedNode {
+					st.children[i].hash(append(prefix, byte(i)))
+				}
+
+				break
+			}
+		}
+
+		// Add new child
+		if st.children[idx] == nil {
+			st.children[idx] = newLeaf(st.owner, key[1:], value, st.writeFn)
+		} else {
+			st.children[idx].insert(key[1:], value, append(prefix, key[0]))
+		}
+	case extNode: /* Ext */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Check if chunks are identical. If so, recurse into
+		// the child node. Otherwise, the key has to be split
+		// into 1) an optional common prefix, 2) the fullnode
+		// representing the two differing path, and 3) a leaf
+		// for each of the differentiated subtrees.
+		if diffidx == len(st.key) {
+			// Ext key and key segment are identical, recurse into
+			// the child node.
+			st.children[0].insert(key[diffidx:], value, append(prefix, key[:diffidx]...))
+
+			return
+		}
+		// Save the original part. Depending if the break is
+		// at the extension's last byte or not, create an
+		// intermediate extension or use the extension's child
+		// node directly.
+		var n *StackTrie
+		if diffidx < len(st.key)-1 {
+			// Break on the non-last byte, insert an intermediate
+			// extension. The path prefix of the newly-inserted
+			// extension should also contain the different byte.
+			n = newExt(st.owner, st.key[diffidx+1:], st.children[0], st.writeFn)
+			n.hash(append(prefix, st.key[:diffidx+1]...))
+		} else {
+			// Break on the last byte, no need to insert
+			// an extension node: reuse the current node.
+			// The path prefix of the original part should
+			// still be same.
+			n = st.children[0]
+			n.hash(append(prefix, st.key...))
+		}
+
+		var p *StackTrie
+
+		if diffidx == 0 {
+			// the break is on the first byte, so
+			// the current node is converted into
+			// a branch node.
+			st.children[0] = nil
+			p = st
+			st.nodeType = branchNode
+		} else {
+			// the common prefix is at least one byte
+			// long, insert a new intermediate branch
+			// node.
+			st.children[0] = stackTrieFromPool(st.writeFn, st.owner)
+			st.children[0].nodeType = branchNode
+			p = st.children[0]
+		}
+
+		// Create a leaf for the inserted part
+		o := newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
+
+		// Insert both child leaves where they belong:
+		origIdx := st.key[diffidx]
+		newIdx := key[diffidx]
+		p.children[origIdx] = n
+		p.children[newIdx] = o
+		st.key = st.key[:diffidx]
+	case leafNode: /* Leaf */
+		// Compare both key chunks and see where they differ
+		diffidx := st.getDiffIndex(key)
+
+		// Overwriting a key isn't supported, which means that
+		// the current leaf is expected to be split into 1) an
+		// optional extension for the common prefix of these 2
+		// keys, 2) a fullnode selecting the path on which the
+		// keys differ, and 3) one leaf for the differentiated
+		// component of each key.
+		if diffidx >= len(st.key) {
+			panic("Trying to insert into existing key")
+		}
+
+		// Check if the split occurs at the first nibble of the
+		// chunk. In that case, no prefix extnode is necessary.
+		// Otherwise, create that
+		var p *StackTrie
+
+		if diffidx == 0 {
+			// Convert current leaf into a branch
+			st.nodeType = branchNode
+			p = st
+			st.children[0] = nil
+		} else {
+			// Convert current node into an ext,
+			// and insert a child branch node.
+			st.nodeType = extNode
+			st.children[0] = NewStackTrieWithOwner(st.writeFn, st.owner)
+			st.children[0].nodeType = branchNode
+			p = st.children[0]
+		}
+
+		// Create the two child leaves: one containing the original
+		// value and another containing the new value. The child leaf
+		// is hashed directly in order to free up some memory.
+		origIdx := st.key[diffidx]
+		p.children[origIdx] = newLeaf(st.owner, st.key[diffidx+1:], st.val, st.writeFn)
+		p.children[origIdx].hash(append(prefix, st.key[:diffidx+1]...))
+
+		newIdx := key[diffidx]
+		p.children[newIdx] = newLeaf(st.owner, key[diffidx+1:], value, st.writeFn)
+
+		// Finally, cut off the key part that has been passed
+		// over to the children.
+		st.key = st.key[:diffidx]
+		st.val = nil
+	case emptyNode: /* Empty */
+		st.nodeType = leafNode
+		st.key = key
+		st.val = value
+	case hashedNode:
+		panic("trying to insert into hash")
+	default:
+		panic("invalid type")
+	}
+}
+
+// hash converts st into a 'hashedNode', if possible. Possible outcomes:
+//
+// 1. The rlp-encoded value was >= 32 bytes:
+//   - Then the 32-byte `hash` will be accessible in `st.val`.
+//   - And the 'st.type' will be 'hashedNode'
+//
+// 2. The rlp-encoded value was < 32 bytes
+//   - Then the <32 byte rlp-encoded value will be accessible in 'st.val'.
+//   - And the 'st.type' will be 'hashedNode' AGAIN
+//
+// This method also sets 'st.type' to hashedNode, and clears 'st.key'.
+// It is a convenience wrapper around hashRec that manages a pooled hasher.
+func (st *StackTrie) hash(path []byte) {
+	h := newHasher(false)
+	defer returnHasherToPool(h)
+
+	st.hashRec(h, path)
+}
+
+// hashRec recursively hashes the subtree rooted at st, turning every node
+// into a hashedNode. On return st.val holds either the 32-byte hash of this
+// node or, if the RLP encoding is shorter than 32 bytes, the raw encoding
+// itself. Hashed children are released back to the node pool as a side
+// effect, and >=32-byte nodes are forwarded to writeFn (if set).
+func (st *StackTrie) hashRec(hasher *hasher, path []byte) {
+	// The switch below sets this to the RLP-encoding of this node.
+	var encodedNode []byte
+
+	switch st.nodeType {
+	case hashedNode:
+		// Already hashed; nothing to do.
+		return
+	case emptyNode:
+		st.val = types.EmptyRootHash.Bytes()
+		st.key = st.key[:0]
+		st.nodeType = hashedNode
+
+		return
+	case branchNode:
+		var nodes rawFullNode
+
+		for i, child := range st.children {
+			if child == nil {
+				nodes[i] = nilValueNode
+
+				continue
+			}
+
+			child.hashRec(hasher, append(path, byte(i)))
+
+			// Embed short nodes inline, reference long ones by hash.
+			if len(child.val) < 32 {
+				nodes[i] = rawNode(child.val)
+			} else {
+				nodes[i] = hashNode(child.val)
+			}
+
+			// Release child back to pool.
+			st.children[i] = nil
+
+			returnToPool(child)
+		}
+
+		nodes.encode(hasher.encbuf)
+		encodedNode = hasher.encodedBytes()
+	case extNode:
+		st.children[0].hashRec(hasher, append(path, st.key...))
+
+		n := rawShortNode{Key: hexToCompact(st.key)}
+
+		if len(st.children[0].val) < 32 {
+			n.Val = rawNode(st.children[0].val)
+		} else {
+			n.Val = hashNode(st.children[0].val)
+		}
+
+		n.encode(hasher.encbuf)
+		encodedNode = hasher.encodedBytes()
+
+		// Release child back to pool.
+		returnToPool(st.children[0])
+		st.children[0] = nil
+	case leafNode:
+		// Re-append the terminator nibble stripped during insertion.
+		st.key = append(st.key, byte(16))
+		n := rawShortNode{Key: hexToCompact(st.key), Val: valueNode(st.val)}
+
+		n.encode(hasher.encbuf)
+		encodedNode = hasher.encodedBytes()
+	default:
+		panic("invalid node type")
+	}
+
+	st.nodeType = hashedNode
+	st.key = st.key[:0]
+
+	if len(encodedNode) < 32 {
+		st.val = types.CopyBytes(encodedNode)
+
+		return
+	}
+
+	// Write the hash to the 'val'. We allocate a new val here to not mutate
+	// input values
+	st.val = hasher.hashData(encodedNode)
+	if st.writeFn != nil {
+		st.writeFn(st.owner, path, types.BytesToHash(st.val), encodedNode)
+	}
+}
+
+// Hash returns the hash of the current node.
+func (st *StackTrie) Hash() (h types.Hash) {
+	hasher := newHasher(false)
+	defer returnHasherToPool(hasher)
+
+	st.hashRec(hasher, nil)
+
+	if len(st.val) == 32 {
+		// hashRec already produced a full 32-byte hash.
+		copy(h[:], st.val)
+
+		return h
+	}
+	// If the node's RLP isn't 32 bytes long, the node will not
+	// be hashed, and instead contain the rlp-encoding of the
+	// node. For the top level node, we need to force the hashing.
+	// The Write/Read error returns are ignored: hash.Hash's contract
+	// is that Write never returns an error.
+	hasher.sha.Reset()
+	hasher.sha.Write(st.val)
+	hasher.sha.Read(h[:])
+
+	return h
+}
+
+// Commit will firstly hash the entire trie if it's still not hashed
+// and then commit all nodes to the associated database. Actually most
+// of the trie nodes MAY have been committed already. The main purpose
+// here is to commit the root node.
+//
+// The associated database is expected, otherwise the whole commit
+// functionality should be disabled.
+//
+// Returns ErrCommitDisabled when no write function is configured.
+func (st *StackTrie) Commit() (h types.Hash, err error) {
+	if st.writeFn == nil {
+		return types.Hash{}, ErrCommitDisabled
+	}
+
+	hasher := newHasher(false)
+	defer returnHasherToPool(hasher)
+
+	st.hashRec(hasher, nil)
+
+	if len(st.val) == 32 {
+		// Root already hashed (and written out by hashRec).
+		copy(h[:], st.val)
+
+		return h, nil
+	}
+	// If the node's RLP isn't 32 bytes long, the node will not
+	// be hashed (and committed), and instead contain the rlp-encoding of the
+	// node. For the top level node, we need to force the hashing+commit.
+	hasher.sha.Reset()
+	hasher.sha.Write(st.val)
+	hasher.sha.Read(h[:])
+
+	st.writeFn(st.owner, nil, h, st.val)
+
+	return h, nil
+}
diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go
new file mode 100644
index 0000000000..c4080c5d67
--- /dev/null
+++ b/trie/stacktrie_test.go
@@ -0,0 +1,395 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "math/big"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// TestStackTrieInsertAndHash replays fixed (key, value) sequences and checks
+// the root hash after every prefix length of each sequence against a
+// precomputed expectation. A single StackTrie is reused via Reset(), which
+// also exercises the reuse path.
+func TestStackTrieInsertAndHash(t *testing.T) {
+	type KeyValueHash struct {
+		K string // Hex string for key.
+		V string // Value, directly converted to bytes.
+		H string // Expected root hash after insert of (K, V) to an existing trie.
+	}
+	tests := [][]KeyValueHash{
+		{ // {0:0, 7:0, f:0}
+			{"00", "v_______________________0___0", "5cb26357b95bb9af08475be00243ceb68ade0b66b5cd816b0c18a18c612d2d21"},
+			{"70", "v_______________________0___1", "8ff64309574f7a437a7ad1628e690eb7663cfde10676f8a904a8c8291dbc1603"},
+			{"f0", "v_______________________0___2", "9e3a01bd8d43efb8e9d4b5506648150b8e3ed1caea596f84ee28e01a72635470"},
+		},
+		{ // {1:0cc, e:{1:fc, e:fc}}
+			{"10cc", "v_______________________1___0", "233e9b257843f3dfdb1cce6676cdaf9e595ac96ee1b55031434d852bc7ac9185"},
+			{"e1fc", "v_______________________1___1", "39c5e908ae83d0c78520c7c7bda0b3782daf594700e44546e93def8f049cca95"},
+			{"eefc", "v_______________________1___2", "d789567559fd76fe5b7d9cc42f3750f942502ac1c7f2a466e2f690ec4b6c2a7c"},
+		},
+		{ // {b:{a:ac, b:ac}, d:acc}
+			{"baac", "v_______________________2___0", "8be1c86ba7ec4c61e14c1a9b75055e0464c2633ae66a055a24e75450156a5d42"},
+			{"bbac", "v_______________________2___1", "8495159b9895a7d88d973171d737c0aace6fe6ac02a4769fff1bc43bcccce4cc"},
+			{"dacc", "v_______________________2___2", "9bcfc5b220a27328deb9dc6ee2e3d46c9ebc9c69e78acda1fa2c7040602c63ca"},
+		},
+		{ // {0:0cccc, 2:456{0:0, 2:2}
+			{"00cccc", "v_______________________3___0", "e57dc2785b99ce9205080cb41b32ebea7ac3e158952b44c87d186e6d190a6530"},
+			{"245600", "v_______________________3___1", "0335354adbd360a45c1871a842452287721b64b4234dfe08760b243523c998db"},
+			{"245622", "v_______________________3___2", "9e6832db0dca2b5cf81c0e0727bfde6afc39d5de33e5720bccacc183c162104e"},
+		},
+		{ // {1:4567{1:1c, 3:3c}, 3:0cccccc}
+			{"1456711c", "v_______________________4___0", "f2389e78d98fed99f3e63d6d1623c1d4d9e8c91cb1d585de81fbc7c0e60d3529"},
+			{"1456733c", "v_______________________4___1", "101189b3fab852be97a0120c03d95eefcf984d3ed639f2328527de6def55a9c0"},
+			{"30cccccc", "v_______________________4___2", "3780ce111f98d15751dfde1eb21080efc7d3914b429e5c84c64db637c55405b3"},
+		},
+		{ // 8800{1:f, 2:e, 3:d}
+			{"88001f", "v_______________________5___0", "e817db50d84f341d443c6f6593cafda093fc85e773a762421d47daa6ac993bd5"},
+			{"88002e", "v_______________________5___1", "d6e3e6047bdc110edd296a4d63c030aec451bee9d8075bc5a198eee8cda34f68"},
+			{"88003d", "v_______________________5___2", "b6bdf8298c703342188e5f7f84921a402042d0e5fb059969dd53a6b6b1fb989e"},
+		},
+		{ // 0{1:fc, 2:ec, 4:dc}
+			{"01fc", "v_______________________6___0", "693268f2ca80d32b015f61cd2c4dba5a47a6b52a14c34f8e6945fad684e7a0d5"},
+			{"02ec", "v_______________________6___1", "e24ddd44469310c2b785a2044618874bf486d2f7822603a9b8dce58d6524d5de"},
+			{"04dc", "v_______________________6___2", "33fc259629187bbe54b92f82f0cd8083b91a12e41a9456b84fc155321e334db7"},
+		},
+		{ // f{0:fccc, f:ff{0:f, f:f}}
+			{"f0fccc", "v_______________________7___0", "b0966b5aa469a3e292bc5fcfa6c396ae7a657255eef552ea7e12f996de795b90"},
+			{"ffff0f", "v_______________________7___1", "3b1ca154ec2a3d96d8d77bddef0abfe40a53a64eb03cecf78da9ec43799fa3d0"},
+			{"ffffff", "v_______________________7___2", "e75463041f1be8252781be0ace579a44ea4387bf5b2739f4607af676f7719678"},
+		},
+		{ // ff{0:f{0:f, f:f}, f:fcc}
+			{"ff0f0f", "v_______________________8___0", "0928af9b14718ec8262ab89df430f1e5fbf66fac0fed037aff2b6767ae8c8684"},
+			{"ff0fff", "v_______________________8___1", "d870f4d3ce26b0bf86912810a1960693630c20a48ba56be0ad04bc3e9ddb01e6"},
+			{"ffffcc", "v_______________________8___2", "4239f10dd9d9915ecf2e047d6a576bdc1733ed77a30830f1bf29deaf7d8e966f"},
+		},
+		{
+			{"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+			{"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+			{"123f", "x___________________________2", "1164d7299964e74ac40d761f9189b2a3987fae959800d0f7e29d3aaf3eae9e15"},
+		},
+		{
+			{"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+			{"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+			{"124a", "x___________________________2", "661a96a669869d76b7231380da0649d013301425fbea9d5c5fae6405aa31cfce"},
+		},
+		{
+			{"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+			{"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+			{"13aa", "x___________________________2", "6590120e1fd3ffd1a90e8de5bb10750b61079bb0776cca4414dd79a24e4d4356"},
+		},
+		{
+			{"123d", "x___________________________0", "fc453d88b6f128a77c448669710497380fa4588abbea9f78f4c20c80daa797d0"},
+			{"123e", "x___________________________1", "5af48f2d8a9a015c1ff7fa8b8c7f6b676233bd320e8fb57fd7933622badd2cec"},
+			{"2aaa", "x___________________________2", "f869b40e0c55eace1918332ef91563616fbf0755e2b946119679f7ef8e44b514"},
+		},
+		{
+			{"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+			{"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+			{"1234fa", "x___________________________2", "4f4e368ab367090d5bc3dbf25f7729f8bd60df84de309b4633a6b69ab66142c0"},
+		},
+		{
+			{"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+			{"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+			{"1235aa", "x___________________________2", "21840121d11a91ac8bbad9a5d06af902a5c8d56a47b85600ba813814b7bfcb9b"},
+		},
+		{
+			{"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+			{"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+			{"124aaa", "x___________________________2", "ea4040ddf6ae3fbd1524bdec19c0ab1581015996262006632027fa5cf21e441e"},
+		},
+		{
+			{"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+			{"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+			{"13aaaa", "x___________________________2", "e4beb66c67e44f2dd8ba36036e45a44ff68f8d52942472b1911a45f886a34507"},
+		},
+		{
+			{"1234da", "x___________________________0", "1c4b4462e9f56a80ca0f5d77c0d632c41b0102290930343cf1791e971a045a79"},
+			{"1234ea", "x___________________________1", "2f502917f3ba7d328c21c8b45ee0f160652e68450332c166d4ad02d1afe31862"},
+			{"2aaaaa", "x___________________________2", "5f5989b820ff5d76b7d49e77bb64f26602294f6c42a1a3becc669cd9e0dc8ec9"},
+		},
+		{
+			{"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+			{"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+			{"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+			{"1234fa", "x___________________________3", "65bb3aafea8121111d693ffe34881c14d27b128fd113fa120961f251fe28428d"},
+		},
+		{
+			{"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+			{"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+			{"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+			{"1235aa", "x___________________________3", "f670e4d2547c533c5f21e0045442e2ecb733f347ad6d29ef36e0f5ba31bb11a8"},
+		},
+		{
+			{"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+			{"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+			{"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+			{"124aaa", "x___________________________3", "c17464123050a9a6f29b5574bb2f92f6d305c1794976b475b7fb0316b6335598"},
+		},
+		{
+			{"000000", "x___________________________0", "3b32b7af0bddc7940e7364ee18b5a59702c1825e469452c8483b9c4e0218b55a"},
+			{"1234da", "x___________________________1", "3ab152a1285dca31945566f872c1cc2f17a770440eda32aeee46a5e91033dde2"},
+			{"1234ea", "x___________________________2", "0cccc87f96ddef55563c1b3be3c64fff6a644333c3d9cd99852cb53b6412b9b8"},
+			{"13aaaa", "x___________________________3", "aa8301be8cb52ea5cd249f5feb79fb4315ee8de2140c604033f4b3fff78f0105"},
+		},
+		{
+			{"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+			{"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+			{"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+			{"123f", "x___________________________3", "80f7bad1893ca57e3443bb3305a517723a74d3ba831bcaca22a170645eb7aafb"},
+		},
+		{
+			{"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+			{"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+			{"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+			{"124a", "x___________________________3", "383bc1bb4f019e6bc4da3751509ea709b58dd1ac46081670834bae072f3e9557"},
+		},
+		{
+			{"0000", "x___________________________0", "cb8c09ad07ae882136f602b3f21f8733a9f5a78f1d2525a8d24d1c13258000b2"},
+			{"123d", "x___________________________1", "8f09663deb02f08958136410dc48565e077f76bb6c9d8c84d35fc8913a657d31"},
+			{"123e", "x___________________________2", "0d230561e398c579e09a9f7b69ceaf7d3970f5a436fdb28b68b7a37c5bdd6b80"},
+			{"13aa", "x___________________________3", "ff0dc70ce2e5db90ee42a4c2ad12139596b890e90eb4e16526ab38fa465b35cf"},
+		},
+	}
+	st := NewStackTrie(nil)
+	for i, test := range tests {
+		// The StackTrie does not allow Insert(), Hash(), Insert(), ...
+		// so we will create new trie for every sequence length of inserts.
+		for l := 1; l <= len(test); l++ {
+			st.Reset()
+			for j := 0; j < l; j++ {
+				kv := &test[j]
+				if err := st.TryUpdate(types.StringToBytes(kv.K), []byte(kv.V)); err != nil {
+					t.Fatal(err)
+				}
+			}
+			expected := types.StringToHash(test[l-1].H)
+			if h := st.Hash(); h != expected {
+				t.Errorf("%d(%d): root hash mismatch: %x, expected %x", i, l, h, expected)
+			}
+		}
+	}
+}
+
+// TestSizeBug checks that a single large-key insert produces the same root
+// hash in the stack trie as in the regular in-memory trie.
+// TryUpdate errors are ignored here; this implementation always returns nil.
+func TestSizeBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	leaf := types.StringToBytes("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	value := types.StringToBytes("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+
+	nt.TryUpdate(leaf, value)
+	st.TryUpdate(leaf, value)
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestEmptyBug cross-checks the stack trie against the regular trie for a
+// small fixed set of 32-byte keys with short values.
+func TestEmptyBug(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	//leaf := types.StringToBytes("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	//value := types.StringToBytes("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "9496f4ec2bf9dab484cac6be589e8417d84781be08"},
+		{K: "40edb63a35fcf86c08022722aa3287cdd36440d671b4918131b2514795fefa9c", V: "01"},
+		{K: "b10e2d527612073b26eecdfd717e6a320cf44b4afac2b0732d9fcbe2b7fa0cf6", V: "947a30f7736e48d6599356464ba4c150d8da0302ff"},
+		{K: "c2575a0e9e593c00f959f8c92f12db2869c3395a3b0502d05e2516446f71f85b", V: "02"},
+	}
+
+	for _, kv := range kvs {
+		nt.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+		st.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+	}
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestValLength56 cross-checks both tries for a value whose length sits at
+// an RLP encoding boundary (56-byte payload threshold).
+func TestValLength56(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	//leaf := types.StringToBytes("290decd9548b62a8d60345a988386fc84ba6bc95484008f6362f93160ef3e563")
+	//value := types.StringToBytes("94cf40d0d2b44f2b66e07cace1372ca42b73cf21a3")
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{K: "405787fa12a823e0f2b7631cc41b3ba8828b3321ca811111fa75cd3aa3bb5ace", V: "1111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111"},
+	}
+
+	for _, kv := range kvs {
+		nt.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+		st.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+	}
+
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestUpdateSmallNodes tests a case where the leaves are small (both key and value),
+// which causes a lot of node-within-node. This case was found via fuzzing.
+func TestUpdateSmallNodes(t *testing.T) {
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{"63303030", "3041"}, // stacktrie.Update
+		{"65", "3000"}, // stacktrie.Update
+	}
+	for _, kv := range kvs {
+		nt.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+		st.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+	}
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestUpdateVariableKeys contains a case which stacktrie fails: when keys of different
+// sizes are used, and the second one has the same prefix as the first, then the
+// stacktrie fails, since it's unable to 'expand' on an already added leaf.
+// For all practical purposes, this is fine, since keys are fixed-size length
+// in account and storage tries.
+//
+// The test is marked as 'skipped', and exists just to have the behaviour documented.
+// This case was found via fuzzing.
+func TestUpdateVariableKeys(t *testing.T) {
+	t.SkipNow()
+	st := NewStackTrie(nil)
+	nt := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	kvs := []struct {
+		K string
+		V string
+	}{
+		{"0x33303534636532393561313031676174", "303030"},
+		{"0x3330353463653239356131303167617430", "313131"},
+	}
+	for _, kv := range kvs {
+		nt.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+		st.TryUpdate(types.StringToBytes(kv.K), types.StringToBytes(kv.V))
+	}
+	if nt.Hash() != st.Hash() {
+		t.Fatalf("error %x != %x", st.Hash(), nt.Hash())
+	}
+}
+
+// TestStacktrieNotModifyValues checks that inserting blobs of data into the
+// stacktrie does not mutate the blobs
+func TestStacktrieNotModifyValues(t *testing.T) {
+	st := NewStackTrie(nil)
+	{ // Test a very small trie
+		// Give it the value as a slice with large backing alloc,
+		// so if the stacktrie tries to append, it won't have to realloc
+		value := make([]byte, 1, 100)
+		value[0] = 0x2
+		want := types.CopyBytes(value)
+		st.TryUpdate([]byte{0x01}, value)
+		st.Hash()
+		if have := value; !bytes.Equal(have, want) {
+			t.Fatalf("tiny trie: have %#x want %#x", have, want)
+		}
+		st = NewStackTrie(nil)
+	}
+	// Test with a larger trie
+	keyB := big.NewInt(1)
+	keyDelta := big.NewInt(1)
+	var vals [][]byte
+	// getValue alternates between 32-byte keccak digests (even i) and
+	// short big-endian integers (odd i) to cover both value-size paths.
+	getValue := func(i int) []byte {
+		if i%2 == 0 { // large
+			return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+		} else { //small
+			return big.NewInt(int64(i)).Bytes()
+		}
+	}
+	for i := 0; i < 1000; i++ {
+		key := types.BytesToHash(keyB.Bytes())
+		value := getValue(i)
+		st.TryUpdate(key.Bytes(), value)
+		vals = append(vals, value)
+		keyB = keyB.Add(keyB, keyDelta)
+		keyDelta.Add(keyDelta, types.Big1)
+	}
+	st.Hash()
+	for i := 0; i < 1000; i++ {
+		want := getValue(i)
+
+		have := vals[i]
+		if !bytes.Equal(have, want) {
+			t.Fatalf("item %d, have %#x want %#x", i, have, want)
+		}
+	}
+}
+
+// TestStacktrieSerialization tests that the stacktrie works well if we
+// serialize/unserialize it a lot
+func TestStacktrieSerialization(t *testing.T) {
+	var (
+		st = NewStackTrie(nil)
+		nt = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+		keyB = big.NewInt(1)
+		keyDelta = big.NewInt(1)
+		vals [][]byte
+		keys [][]byte
+	)
+	// getValue alternates between 32-byte keccak digests (even i) and
+	// short big-endian integers (odd i).
+	getValue := func(i int) []byte {
+		if i%2 == 0 { // large
+			return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
+		} else { //small
+			return big.NewInt(int64(i)).Bytes()
+		}
+	}
+	for i := 0; i < 10; i++ {
+		vals = append(vals, getValue(i))
+		keys = append(keys, types.BytesToHash(keyB.Bytes()).Bytes())
+		keyB = keyB.Add(keyB, keyDelta)
+		keyDelta.Add(keyDelta, types.Big1)
+	}
+	for i, k := range keys {
+		nt.TryUpdate(k, types.CopyBytes(vals[i]))
+	}
+
+	// Round-trip the stack trie through MarshalBinary/NewFromBinary before
+	// every insert, then compare the final root against the regular trie.
+	for i, k := range keys {
+		blob, err := st.MarshalBinary()
+		if err != nil {
+			t.Fatal(err)
+		}
+		newSt, err := NewFromBinary(blob, nil)
+		if err != nil {
+			t.Fatal(err)
+		}
+		st = newSt
+		st.TryUpdate(k, types.CopyBytes(vals[i]))
+	}
+	if have, want := st.Hash(), nt.Hash(); have != want {
+		t.Fatalf("have %#x want %#x", have, want)
+	}
+}
diff --git a/trie/trie.go b/trie/trie.go
new file mode 100644
index 0000000000..73181ef797
--- /dev/null
+++ b/trie/trie.go
@@ -0,0 +1,683 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+// Package trie implements Merkle Patricia Tries.
+package trie
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
+// top of a database. Whenever trie performs a commit operation, the generated
+// nodes will be gathered and returned in a set. Once the trie is committed,
+// it's not usable anymore. Callers have to re-create the trie with new root
+// based on the updated trie database.
+//
+// Trie is not safe for concurrent use.
+type Trie struct {
+	// root is the in-memory root node; nil for an empty trie.
+	root node
+	// owner identifies the trie (zero hash for a standard/account trie,
+	// the contract address hash for a storage trie — see trie_id.go).
+	owner types.Hash
+
+	// Keep track of the number leaves which have been inserted since the last
+	// hashing operation. This number will not directly map to the number of
+	// actually unhashed nodes.
+	unhashed int
+
+	// reader is the handler trie can retrieve nodes from.
+	reader *trieReader
+
+	// tracer is the tool to track the trie changes.
+	// It will be reset after each commit operation.
+	tracer *tracer
+
+	// logger for printing error log
+	// (used by the non-Try accessors: Get/Update/Delete).
+	logger Logger
+}
+
+// newFlag returns the cache flag value for a newly created node.
+func (t *Trie) newFlag() nodeFlag {
+ return nodeFlag{dirty: true}
+}
+
+// Copy returns a shallow copy of the Trie that shares the current root,
+// reader and logger but owns its own tracer, so changes tracked on the
+// copy do not leak into the original.
+//
+// Fix: the logger field was previously not copied, leaving the copy with
+// a nil logger and causing a nil-pointer panic in Get/Update/Delete
+// error paths.
+func (t *Trie) Copy() *Trie {
+	return &Trie{
+		root:     t.root,
+		owner:    t.owner,
+		unhashed: t.unhashed,
+		reader:   t.reader,
+		tracer:   t.tracer.copy(),
+		logger:   t.logger,
+	}
+}
+
+// New creates the trie instance with provided trie id and the read-only
+// database. The state specified by trie id must be available, otherwise
+// an error will be returned. The trie root specified by trie id can be
+// zero hash or the sha3 hash of an empty string, then trie is initially
+// empty, otherwise, the root node must be present in database or returns
+// a MissingNodeError if not.
+func New(id *ID, db NodeReader, logger Logger) (*Trie, error) {
+	reader, err := newTrieReader(id.StateRoot, id.Owner, db)
+	if err != nil {
+		return nil, err
+	}
+
+	t := &Trie{
+		owner:  id.Owner,
+		reader: reader,
+		logger: logger,
+	}
+
+	// An empty root (zero hash or empty-root hash) means a fresh trie;
+	// anything else must be resolvable from the database.
+	if root := id.Root; root != types.ZeroHash && root != types.EmptyRootHash {
+		rootnode, err := t.resolveAndTrack(root[:], nil)
+		if err != nil {
+			return nil, err
+		}
+
+		t.root = rootnode
+	}
+
+	return t, nil
+}
+
+// NewEmpty is a shortcut to create empty tree. It's mostly used in tests.
+// The construction error is ignored: an empty root never needs database
+// resolution, so New cannot fail here.
+func NewEmpty(db *Database) *Trie {
+	trie, _ := New(TrieID(types.Hash{}), db, hclog.NewNullLogger())
+
+	return trie
+}
+
+// NodeIterator returns an iterator that returns nodes of the trie. Iteration starts at
+// the key after the given start key.
+func (t *Trie) NodeIterator(start []byte) NodeIterator {
+	it := newNodeIterator(t, start)
+
+	return it
+}
+
+// Get returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// Errors from the underlying store are logged, not returned; use TryGet
+// when the caller needs to observe them.
+func (t *Trie) Get(key []byte) []byte {
+	value, err := t.TryGet(key)
+	if err != nil {
+		t.logger.Error("Unhandled trie error in Trie.Get", "err", err)
+	}
+
+	return value
+}
+
+// TryGet returns the value for key stored in the trie.
+// The value bytes must not be modified by the caller.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryGet(key []byte) ([]byte, error) {
+	hexKey := keybytesToHex(key)
+
+	value, newroot, didResolve, err := t.tryGet(t.root, hexKey, 0)
+	if didResolve && err == nil {
+		// Keep the expanded nodes so repeated reads skip the database.
+		t.root = newroot
+	}
+
+	return value, err
+}
+
+// tryGet walks the subtrie rooted at origNode looking for the value stored
+// under the hex-encoded key, starting the comparison at nibble position pos.
+// It returns the value (nil if absent), the possibly-updated subtrie root
+// (hash nodes resolved along the path are expanded in the returned node),
+// whether any node was resolved from the database, and any loading error.
+func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) {
+	switch n := (origNode).(type) {
+	case nil:
+		// Empty subtrie: key is not present.
+		return nil, nil, false, nil
+	case valueNode:
+		// Reached a stored value; the key has been fully consumed.
+		return n, n, false, nil
+	case *shortNode:
+		if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) {
+			// key not found in trie
+			return nil, n, false, nil
+		}
+
+		value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key))
+		if err == nil && didResolve {
+			// Re-attach the resolved child on a copy so shared nodes stay intact.
+			n = n.copy()
+			n.Val = newnode
+		}
+
+		return value, n, didResolve, err
+	case *fullNode:
+		value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
+		if err == nil && didResolve {
+			n = n.copy()
+			n.Children[key[pos]] = newnode
+		}
+
+		return value, n, didResolve, err
+	case hashNode:
+		// Node not loaded yet: fetch it from the reader and continue below it.
+		child, err := t.resolveAndTrack(n, key[:pos])
+		if err != nil {
+			return nil, n, true, err
+		}
+
+		value, newnode, _, err := t.tryGet(child, key, pos)
+
+		return value, newnode, true, err
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
+	}
+}
+
+// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
+// possible to use keybyte-encoding as the path might contain odd nibbles.
+// It returns the node blob (nil when the path does not exist) together with
+// the number of nodes resolved from the database along the way.
+func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) {
+	item, newroot, resolved, err := t.tryGetNode(t.root, compactToHex(path), 0)
+	if err != nil {
+		return nil, resolved, err
+	}
+
+	// Keep any nodes expanded during the walk for future lookups.
+	if resolved > 0 {
+		t.root = newroot
+	}
+
+	if item == nil {
+		return nil, resolved, nil
+	}
+
+	return item, resolved, nil
+}
+
+// tryGetNode descends the trie following the hex-encoded path and returns the
+// rlp-encoded blob of the node located there. resolved counts how many hash
+// nodes were loaded from the database during the walk, so the caller can
+// decide whether to keep the expanded root.
+func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) {
+	// If non-existent path requested, abort
+	if origNode == nil {
+		return nil, nil, 0, nil
+	}
+
+	// If we reached the requested path, return the current node
+	if pos >= len(path) {
+		// Although we most probably have the original node expanded, encoding
+		// that into consensus form can be nasty (needs to cascade down) and
+		// time consuming. Instead, just pull the hash up from disk directly.
+		var hash hashNode
+		if node, ok := origNode.(hashNode); ok {
+			hash = node
+		} else {
+			hash, _ = origNode.cache()
+		}
+
+		if hash == nil {
+			return nil, origNode, 0, errors.New("non-consensus node")
+		}
+
+		blob, err := t.reader.nodeBlob(path, types.BytesToHash(hash))
+
+		return blob, origNode, 1, err
+	}
+	// Path still needs to be traversed, descend into children
+	switch n := (origNode).(type) {
+	case valueNode:
+		// Path prematurely ended, abort
+		return nil, nil, 0, nil
+	case *shortNode:
+		if len(path)-pos < len(n.Key) || !bytes.Equal(n.Key, path[pos:pos+len(n.Key)]) {
+			// Path branches off from short node
+			return nil, n, 0, nil
+		}
+
+		item, newnode, resolved, err = t.tryGetNode(n.Val, path, pos+len(n.Key))
+		if err == nil && resolved > 0 {
+			// Re-attach the expanded child on a copy of the node.
+			n = n.copy()
+			n.Val = newnode
+		}
+
+		return item, n, resolved, err
+	case *fullNode:
+		item, newnode, resolved, err = t.tryGetNode(n.Children[path[pos]], path, pos+1)
+		if err == nil && resolved > 0 {
+			n = n.copy()
+			n.Children[path[pos]] = newnode
+		}
+
+		return item, n, resolved, err
+	case hashNode:
+		// Node not loaded yet: fetch it and continue, counting the resolution.
+		child, err := t.resolveAndTrack(n, path[:pos])
+		if err != nil {
+			return nil, n, 1, err
+		}
+
+		item, newnode, resolved, err := t.tryGetNode(child, path, pos)
+
+		return item, newnode, resolved + 1, err
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode))
+	}
+}
+
+// Update associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie. Errors are logged rather than returned; use
+// TryUpdate to observe them.
+func (t *Trie) Update(key, value []byte) {
+	err := t.TryUpdate(key, value)
+	if err != nil {
+		t.logger.Error("Unhandled trie error in Trie.Update", "err", err)
+	}
+}
+
+// TryUpdate associates key with value in the trie. Subsequent calls to
+// Get will return value. If value has length zero, any existing value
+// is deleted from the trie and calls to Get will return nil.
+//
+// The value bytes must not be modified by the caller while they are
+// stored in the trie.
+//
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryUpdate(key, value []byte) error {
+	// Delegate to tryUpdate, which treats a zero-length value as deletion.
+	return t.tryUpdate(key, value)
+}
+
+// tryUpdate expects an RLP-encoded value and performs the core function
+// for TryUpdate and TryUpdateAccount. A zero-length value deletes the key.
+func (t *Trie) tryUpdate(key, value []byte) error {
+	t.unhashed++
+
+	hexKey := keybytesToHex(key)
+
+	// Deletion path: an empty value removes any existing entry.
+	if len(value) == 0 {
+		_, root, err := t.delete(t.root, nil, hexKey)
+		if err != nil {
+			return err
+		}
+
+		t.root = root
+
+		return nil
+	}
+
+	// Insert/overwrite path.
+	_, root, err := t.insert(t.root, nil, hexKey, valueNode(value))
+	if err != nil {
+		return err
+	}
+
+	t.root = root
+
+	return nil
+}
+
+// insert adds value under the hex-encoded key into the subtrie rooted at n.
+// prefix is the nibble path walked from the root so far; it is used for
+// database resolution and tracer bookkeeping. It returns whether the subtrie
+// changed, the new subtrie root, and any error from resolving missing nodes.
+func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) {
+	if len(key) == 0 {
+		if v, ok := n.(valueNode); ok {
+			// Overwriting an existing value: only dirty if the bytes differ.
+			//nolint:forcetypeassert
+			return !bytes.Equal(v, value.(valueNode)), value, nil
+		}
+
+		return true, value, nil
+	}
+
+	switch n := n.(type) {
+	case *shortNode:
+		matchlen := prefixLen(key, n.Key)
+		// If the whole key matches, keep this short node as is
+		// and only update the value.
+		if matchlen == len(n.Key) {
+			dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value)
+			if !dirty || err != nil {
+				return false, n, err
+			}
+
+			return true, &shortNode{n.Key, nn, t.newFlag()}, nil
+		}
+		// Otherwise branch out at the index where they differ.
+		branch := &fullNode{flags: t.newFlag()}
+
+		var err error
+
+		// Re-insert the existing short node's value below the branch.
+		_, branch.Children[n.Key[matchlen]], err = t.insert(
+			nil,
+			append(prefix, n.Key[:matchlen+1]...),
+			n.Key[matchlen+1:],
+			n.Val,
+		)
+		if err != nil {
+			return false, nil, err
+		}
+
+		// Insert the new value below the branch.
+		_, branch.Children[key[matchlen]], err = t.insert(
+			nil,
+			append(prefix, key[:matchlen+1]...),
+			key[matchlen+1:],
+			value,
+		)
+		if err != nil {
+			return false, nil, err
+		}
+		// Replace this shortNode with the branch if it occurs at index 0.
+		if matchlen == 0 {
+			return true, branch, nil
+		}
+		// New branch node is created as a child of the original short node.
+		// Track the newly inserted node in the tracer. The node identifier
+		// passed is the path from the root node.
+		t.tracer.onInsert(append(prefix, key[:matchlen]...))
+
+		// Replace it with a short node leading up to the branch.
+		return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil
+	case *fullNode:
+		dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value)
+		if !dirty || err != nil {
+			return false, n, err
+		}
+
+		// Mutate a copy so cached/shared nodes are never modified in place.
+		n = n.copy()
+		n.flags = t.newFlag()
+		n.Children[key[0]] = nn
+
+		return true, n, nil
+	case nil:
+		// New short node is created and track it in the tracer. The node identifier
+		// passed is the path from the root node. Note the valueNode won't be tracked
+		// since it's always embedded in its parent.
+		t.tracer.onInsert(prefix)
+
+		return true, &shortNode{key, value, t.newFlag()}, nil
+	case hashNode:
+		// We've hit a part of the trie that isn't loaded yet. Load
+		// the node and insert into it. This leaves all child nodes on
+		// the path to the value in the trie.
+		rn, err := t.resolveAndTrack(n, prefix)
+		if err != nil {
+			return false, nil, err
+		}
+
+		dirty, nn, err := t.insert(rn, prefix, key, value)
+		if !dirty || err != nil {
+			return false, rn, err
+		}
+
+		return true, nn, nil
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v", n, n))
+	}
+}
+
+// Delete removes any existing value for key from the trie.
+// Errors are logged rather than returned; use TryDelete to observe them.
+func (t *Trie) Delete(key []byte) {
+	err := t.TryDelete(key)
+	if err != nil {
+		t.logger.Error("Unhandled trie error in Trie.Delete", "err", err)
+	}
+}
+
+// TryDelete removes any existing value for key from the trie.
+// If a node was not found in the database, a MissingNodeError is returned.
+func (t *Trie) TryDelete(key []byte) error {
+	t.unhashed++
+
+	_, root, err := t.delete(t.root, nil, keybytesToHex(key))
+	if err != nil {
+		return err
+	}
+
+	t.root = root
+
+	return nil
+}
+
+// delete returns the new root of the trie with key deleted.
+// It reduces the trie to minimal form by simplifying
+// nodes on the way up after deleting recursively.
+//
+// The boolean result reports whether the subtrie rooted at n changed;
+// prefix is the nibble path from the root, used for node resolution and
+// tracer bookkeeping.
+func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
+	switch n := n.(type) {
+	case *shortNode:
+		matchlen := prefixLen(key, n.Key)
+		if matchlen < len(n.Key) {
+			return false, n, nil // don't replace n on mismatch
+		}
+
+		if matchlen == len(key) {
+			// The matched short node is deleted entirely and track
+			// it in the deletion set. The same the valueNode doesn't
+			// need to be tracked at all since it's always embedded.
+			t.tracer.onDelete(prefix)
+
+			return true, nil, nil // remove n entirely for whole matches
+		}
+		// The key is longer than n.Key. Remove the remaining suffix
+		// from the subtrie. Child can never be nil here since the
+		// subtrie must contain at least two other values with keys
+		// longer than n.Key.
+		dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):])
+		if !dirty || err != nil {
+			return false, n, err
+		}
+
+		switch child := child.(type) {
+		case *shortNode:
+			// The child shortNode is merged into its parent, track
+			// is deleted as well.
+			t.tracer.onDelete(append(prefix, n.Key...))
+
+			// Deleting from the subtrie reduced it to another
+			// short node. Merge the nodes to avoid creating a
+			// shortNode{..., shortNode{...}}. Use concat (which
+			// always creates a new slice) instead of append to
+			// avoid modifying n.Key since it might be shared with
+			// other nodes.
+			return true, &shortNode{concat(n.Key, child.Key...), child.Val, t.newFlag()}, nil
+		default:
+			return true, &shortNode{n.Key, child, t.newFlag()}, nil
+		}
+	case *fullNode:
+		dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:])
+		if !dirty || err != nil {
+			return false, n, err
+		}
+
+		// Mutate a copy so cached/shared nodes are never modified in place.
+		n = n.copy()
+		n.flags = t.newFlag()
+		n.Children[key[0]] = nn
+
+		// Because n is a full node, it must've contained at least two children
+		// before the delete operation. If the new child value is non-nil, n still
+		// has at least two children after the deletion, and cannot be reduced to
+		// a short node.
+		if nn != nil {
+			return true, n, nil
+		}
+		// Reduction:
+		// Check how many non-nil entries are left after deleting and
+		// reduce the full node to a short node if only one entry is
+		// left. Since n must've contained at least two children
+		// before deletion (otherwise it would not be a full node) n
+		// can never be reduced to nil.
+		//
+		// When the loop is done, pos contains the index of the single
+		// value that is left in n or -2 if n contains at least two
+		// values.
+		pos := -1
+
+		for i, cld := range &n.Children {
+			if cld != nil {
+				if pos == -1 {
+					pos = i
+				} else {
+					pos = -2
+
+					break
+				}
+			}
+		}
+
+		if pos >= 0 {
+			if pos != 16 {
+				// If the remaining entry is a short node, it replaces
+				// n and its key gets the missing nibble tacked to the
+				// front. This avoids creating an invalid
+				// shortNode{..., shortNode{...}}. Since the entry
+				// might not be loaded yet, resolve it just for this
+				// check.
+				cnode, err := t.resolve(n.Children[pos], append(prefix, byte(pos)))
+				if err != nil {
+					return false, nil, err
+				}
+
+				if cnode, ok := cnode.(*shortNode); ok {
+					// Replace the entire full node with the short node.
+					// Mark the original short node as deleted since the
+					// value is embedded into the parent now.
+					t.tracer.onDelete(append(prefix, byte(pos)))
+
+					k := append([]byte{byte(pos)}, cnode.Key...)
+
+					return true, &shortNode{k, cnode.Val, t.newFlag()}, nil
+				}
+			}
+			// Otherwise, n is replaced by a one-nibble short node
+			// containing the child.
+			return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil
+		}
+
+		// n still contains at least two values and cannot be reduced.
+		return true, n, nil
+	case valueNode:
+		// The key terminates on a stored value: drop it.
+		return true, nil, nil
+	case nil:
+		// Nothing to delete in an empty subtrie.
+		return false, nil, nil
+	case hashNode:
+		// We've hit a part of the trie that isn't loaded yet. Load
+		// the node and delete from it. This leaves all child nodes on
+		// the path to the value in the trie.
+		rn, err := t.resolveAndTrack(n, prefix)
+		if err != nil {
+			return false, nil, err
+		}
+
+		dirty, nn, err := t.delete(rn, prefix, key)
+		if !dirty || err != nil {
+			return false, rn, err
+		}
+
+		return true, nn, nil
+	default:
+		panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key))
+	}
+}
+
+// concat returns a fresh slice holding s1 followed by s2. Unlike append,
+// it always allocates a new backing array, so neither input is aliased.
+func concat(s1 []byte, s2 ...byte) []byte {
+	out := make([]byte, 0, len(s1)+len(s2))
+	out = append(out, s1...)
+	out = append(out, s2...)
+
+	return out
+}
+
+// resolve expands n when it is a hashNode reference by loading the real
+// node from the reader; any other node kind is returned untouched.
+func (t *Trie) resolve(n node, prefix []byte) (node, error) {
+	hn, ok := n.(hashNode)
+	if !ok {
+		return n, nil
+	}
+
+	return t.resolveAndTrack(hn, prefix)
+}
+
+// resolveAndTrack loads node from the underlying store with the given node hash
+// and path prefix and also tracks the loaded node blob in tracer treated as the
+// node's original value. The rlp-encoded blob is preferred to be loaded from
+// database because it's easy to decode node while complex to encode node to blob.
+// A MissingNodeError (from the reader) is returned when the blob is absent.
+func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
+	blob, err := t.reader.nodeBlob(prefix, types.BytesToHash(n))
+	if err != nil {
+		return nil, err
+	}
+
+	// Record the original encoding before decoding it into a node.
+	t.tracer.onRead(prefix, blob)
+
+	return mustDecodeNode(n, blob), nil
+}
+
+// Hash returns the root hash of the trie. It does not write to the
+// database and can be used even if the trie doesn't have one.
+func (t *Trie) Hash() types.Hash {
+	hashed, cached, _ := t.hashRoot()
+
+	// Keep the cached (hash-annotated) tree so later hashing is cheap.
+	t.root = cached
+
+	//nolint:forcetypeassert
+	return types.BytesToHash(hashed.(hashNode))
+}
+
+// Commit collects all dirty nodes in the trie and replaces them with the
+// corresponding node hash. All collected nodes (including dirty leaves if
+// collectLeaf is true) will be encapsulated into a nodeset for return.
+// The returned nodeset can be nil if the trie is clean (nothing to commit).
+// Once the trie is committed, it's not usable anymore. A new trie must
+// be created with new root and updated trie database for following usage.
+// It returns the root hash, the set of dirty nodes, and any commit error.
+func (t *Trie) Commit(collectLeaf bool) (types.Hash, *NodeSet, error) {
+	// The tracer's change log only covers one commit cycle.
+	defer t.tracer.reset()
+
+	// Trie is empty and can be classified into two types of situations:
+	// - The trie was empty and no update happens
+	// - The trie was non-empty and all nodes are dropped
+	if t.root == nil {
+		// Wrap tracked deletions as the return
+		set := NewNodeSet(t.owner)
+		t.tracer.markDeletions(set)
+
+		return types.EmptyRootHash, set, nil
+	}
+
+	// Derive the hash for all dirty nodes first. We hold the assumption
+	// in the following procedure that all nodes are hashed.
+	rootHash := t.Hash()
+
+	// Do a quick check if we really need to commit. This can happen e.g.
+	// if we load a trie for reading storage values, but don't write to it.
+	if hashedNode, dirty := t.root.cache(); !dirty {
+		// Replace the root node with the origin hash in order to
+		// ensure all resolved nodes are dropped after the commit.
+		t.root = hashedNode
+
+		return rootHash, nil, nil
+	}
+
+	h := newCommitter(t.owner, t.tracer, collectLeaf)
+
+	newRoot, nodes, err := h.Commit(t.root)
+	if err != nil {
+		return types.Hash{}, nil, err
+	}
+
+	t.root = newRoot
+
+	return rootHash, nodes, nil
+}
+
+// hashRoot calculates the root hash of the given trie. It returns the
+// hashed root, the cached (hash-annotated) root, and a nil error.
+func (t *Trie) hashRoot() (node, node, error) {
+	if t.root == nil {
+		return hashNode(types.EmptyRootHash.Bytes()), nil, nil
+	}
+
+	// If the number of changes is below 100, we let one thread handle it
+	hasher := newHasher(t.unhashed >= 100)
+	defer returnHasherToPool(hasher)
+
+	hashed, cached := hasher.hash(t.root, true)
+	t.unhashed = 0
+
+	return hashed, cached, nil
+}
+
+// Reset drops the referenced root node and cleans all internal state.
+func (t *Trie) Reset() {
+	t.root, t.owner = nil, types.Hash{}
+	t.unhashed = 0
+	t.tracer.reset()
+}
diff --git a/trie/trie_id.go b/trie/trie_id.go
new file mode 100644
index 0000000000..de8b70a0ef
--- /dev/null
+++ b/trie/trie_id.go
@@ -0,0 +1,55 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import "github.com/dogechain-lab/dogechain/types"
+
+// ID is the identifier for uniquely identifying a trie.
+// Owner is the zero hash for standard/state tries (see StateTrieID and
+// TrieID below) and non-zero only for contract storage tries.
+type ID struct {
+	StateRoot types.Hash // The root of the corresponding state(block.root)
+	Owner     types.Hash // The contract address hash which the trie belongs to
+	Root      types.Hash // The root hash of trie
+}
+
+// StateTrieID constructs an identifier for state trie with the provided state root.
+// The state trie has no owner, and its trie root equals the state root.
+func StateTrieID(root types.Hash) *ID {
+	id := &ID{StateRoot: root, Root: root}
+	id.Owner = types.Hash{}
+
+	return id
+}
+
+// StorageTrieID constructs an identifier for storage trie which belongs to a certain
+// state and contract specified by the stateRoot and owner.
+func StorageTrieID(stateRoot types.Hash, owner types.Hash, root types.Hash) *ID {
+	return &ID{StateRoot: stateRoot, Owner: owner, Root: root}
+}
+
+// TrieID constructs an identifier for a standard trie(not a second-layer trie)
+// with provided root. It's mostly used in tests and some other tries like CHT trie.
+func TrieID(root types.Hash) *ID {
+	id := &ID{StateRoot: root, Root: root}
+	id.Owner = types.Hash{}
+
+	return id
+}
diff --git a/trie/trie_reader.go b/trie/trie_reader.go
new file mode 100644
index 0000000000..18d60b58b6
--- /dev/null
+++ b/trie/trie_reader.go
@@ -0,0 +1,113 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "fmt"
+
+ "github.com/dogechain-lab/dogechain/types"
+)
+
+// Reader wraps the Node and NodeBlob method of a backing trie store.
+// Implementations return (nil, nil) — not an error — when a node is absent;
+// callers (trieReader) translate that into a MissingNodeError.
+type Reader interface {
+	// Node retrieves the trie node with the provided trie identifier, hexary
+	// node path and the corresponding node hash.
+	// No error will be returned if the node is not found.
+	Node(owner types.Hash, path []byte, hash types.Hash) (node, error)
+
+	// NodeBlob retrieves the RLP-encoded trie node blob with the provided trie
+	// identifier, hexary node path and the corresponding node hash.
+	// No error will be returned if the node is not found.
+	NodeBlob(owner types.Hash, path []byte, hash types.Hash) ([]byte, error)
+}
+
+// NodeReader wraps all the necessary functions for accessing trie node.
+type NodeReader interface {
+	// GetReader returns a reader for accessing all trie nodes with provided
+	// state root. Nil is returned in case the state is not available.
+	// (newTrieReader turns a nil reader into a "state not found" error.)
+	GetReader(root types.Hash) Reader
+}
+
+// trieReader is a wrapper of the underlying node reader. It's not safe
+// for concurrent usage.
+type trieReader struct {
+	// owner is passed through to every Reader lookup (zero for state tries).
+	owner  types.Hash
+	// reader is the backing store; nil means all reads fail (see newEmptyReader).
+	reader Reader
+	banned map[string]struct{} // Marker to prevent node from being accessed, for tests
+}
+
+// newTrieReader initializes the trie reader with the given node reader.
+// It fails when the backing database has no reader for the state root.
+func newTrieReader(stateRoot, owner types.Hash, db NodeReader) (*trieReader, error) {
+	r := db.GetReader(stateRoot)
+	if r == nil {
+		return nil, fmt.Errorf("state not found #%x", stateRoot)
+	}
+
+	return &trieReader{owner: owner, reader: r}, nil
+}
+
+// newEmptyReader initializes the pure in-memory reader. All read operations
+// should be forbidden and returns the MissingNodeError.
+func newEmptyReader() *trieReader {
+	return new(trieReader)
+}
+
+// node retrieves the trie node with the provided trie node information.
+// A MissingNodeError will be returned in case the node is not found or
+// any error is encountered.
+func (r *trieReader) node(path []byte, hash types.Hash) (node, error) {
+	// Honour the test-only ban list that blocks specific paths.
+	// Reading a nil map is safe, so no nil check is needed here.
+	if _, ok := r.banned[string(path)]; ok {
+		return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+	}
+
+	if r.reader == nil {
+		return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+	}
+
+	n, err := r.reader.Node(r.owner, path, hash)
+	if err != nil || n == nil {
+		return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
+	}
+
+	return n, nil
+}
+
+// nodeBlob retrieves the rlp-encoded trie node with the provided trie node
+// information. A MissingNodeError will be returned in case the node is
+// not found or any error is encountered.
+func (r *trieReader) nodeBlob(path []byte, hash types.Hash) ([]byte, error) {
+	// Perform the logics in tests for preventing trie node access.
+	if r.banned != nil {
+		if _, ok := r.banned[string(path)]; ok {
+			return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+		}
+	}
+
+	// A nil reader (newEmptyReader) forbids all reads.
+	if r.reader == nil {
+		return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path}
+	}
+
+	blob, err := r.reader.NodeBlob(r.owner, path, hash)
+	if err != nil || len(blob) == 0 {
+		return nil, &MissingNodeError{Owner: r.owner, NodeHash: hash, Path: path, err: err}
+	}
+
+	return blob, nil
+}
diff --git a/trie/trie_test.go b/trie/trie_test.go
new file mode 100644
index 0000000000..c7ea4cf59a
--- /dev/null
+++ b/trie/trie_test.go
@@ -0,0 +1,1200 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "hash"
+ "math/big"
+ "math/rand"
+ "reflect"
+ "testing"
+ "testing/quick"
+
+ "github.com/davecgh/go-spew/spew"
+ "github.com/dogechain-lab/dogechain/crypto"
+ "github.com/dogechain-lab/dogechain/helper/kvdb"
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/state/stypes"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+ "golang.org/x/crypto/sha3"
+)
+
+// init tunes spew (used for dumping failing quick-check inputs) so the
+// output is compact and uses the types' own String/Error methods.
+func init() {
+	spew.Config.Indent = "    "
+	spew.Config.DisableMethods = false
+}
+
+// TestEmptyTrie checks that a freshly created trie hashes to the
+// canonical empty root hash.
+func TestEmptyTrie(t *testing.T) {
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	if got, want := tr.Hash(), types.EmptyRootHash; got != want {
+		t.Errorf("expected %x got %x", want, got)
+	}
+}
+
+func TestNull(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ key := make([]byte, 32)
+ value := []byte("test")
+ trie.Update(key, value)
+ if !bytes.Equal(trie.Get(key), value) {
+ t.Fatal("wrong value")
+ }
+}
+
+// TestMissingRoot verifies that opening a trie at an unknown root yields
+// no trie and fails with a MissingNodeError.
+func TestMissingRoot(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	root := types.StringToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33")
+	tr, err := New(TrieID(root), NewDatabase(rawdb.NewMemoryDatabase(), logger), logger)
+	if tr != nil {
+		t.Error("New returned non-nil trie for invalid root")
+	}
+	if _, ok := err.(*MissingNodeError); !ok {
+		t.Errorf("New returned wrong error: %v", err)
+	}
+}
+
+// The two variants run testMissingNode against the on-disk layer and
+// against the in-memory (uncommitted) dirty cache respectively.
+func TestMissingNodeDisk(t *testing.T)    { testMissingNode(t, false) }
+func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
+
+// testMissingNode populates a tiny trie, commits it (optionally flushing
+// to disk), then deletes one interior node from the chosen layer and
+// checks that only operations resolving through the deleted node fail
+// with a MissingNodeError.
+func testMissingNode(t *testing.T, memonly bool) {
+	t.Helper()
+
+	diskdb := rawdb.NewMemoryDatabase()
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(diskdb, logger)
+
+	trie := NewEmpty(triedb)
+	updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
+	updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
+	root, nodes, _ := trie.Commit(false)
+	triedb.Update(NewWithNodeSet(nodes))
+	if !memonly {
+		// Flush to disk so the node removal below has to happen on the
+		// disk layer rather than in the dirty cache.
+		triedb.Commit(root, true, nil)
+	}
+
+	// While all nodes are present, every operation must succeed.
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err := trie.TryGet([]byte("120000"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err = trie.TryGet([]byte("120099"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err = trie.TryGet([]byte("123456"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	err = trie.TryDelete([]byte("123456"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+
+	// Drop one known interior node from the appropriate layer.
+	hash := types.StringToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+	if memonly {
+		delete(triedb.dirties, hash)
+	} else {
+		diskdb.Delete(hash[:])
+	}
+
+	// Paths resolving through the deleted node must now fail; unrelated
+	// paths still work.
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err = trie.TryGet([]byte("120000"))
+	if _, ok := err.(*MissingNodeError); !ok {
+		t.Errorf("Wrong error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err = trie.TryGet([]byte("120099"))
+	if _, ok := err.(*MissingNodeError); !ok {
+		t.Errorf("Wrong error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	_, err = trie.TryGet([]byte("123456"))
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
+	if _, ok := err.(*MissingNodeError); !ok {
+		t.Errorf("Wrong error: %v", err)
+	}
+	trie, _ = New(TrieID(root), triedb, logger)
+	err = trie.TryDelete([]byte("123456"))
+	if _, ok := err.(*MissingNodeError); !ok {
+		t.Errorf("Wrong error: %v", err)
+	}
+}
+
+// TestInsert checks the root hashes of two small tries against known
+// vectors, both for the in-memory hash and for the committed root.
+func TestInsert(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), logger))
+
+	updateString(tr, "doe", "reindeer")
+	updateString(tr, "dog", "puppy")
+	updateString(tr, "dogglesworth", "cat")
+
+	if exp, root := types.StringToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3"), tr.Hash(); root != exp {
+		t.Errorf("case 1: exp %x got %x", exp, root)
+	}
+
+	tr = NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), logger))
+	updateString(tr, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
+
+	root, _, err := tr.Commit(false)
+	if err != nil {
+		t.Fatalf("commit error: %v", err)
+	}
+	if exp := types.StringToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab"); root != exp {
+		t.Errorf("case 2: exp %x got %x", exp, root)
+	}
+}
+
+// TestGet verifies lookups both before and after the trie has been
+// committed and reopened from the backing database.
+func TestGet(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	db := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	trie := NewEmpty(db)
+	updateString(trie, "doe", "reindeer")
+	updateString(trie, "dog", "puppy")
+	updateString(trie, "dogglesworth", "cat")
+
+	verify := func(tr *Trie) {
+		if res := getString(tr, "dog"); !bytes.Equal(res, []byte("puppy")) {
+			t.Errorf("expected puppy got %x", res)
+		}
+		if unknown := getString(tr, "unknown"); unknown != nil {
+			t.Errorf("expected nil got %x", unknown)
+		}
+	}
+
+	// First pass: the fresh, uncommitted trie.
+	verify(trie)
+
+	// Second pass: commit, reopen from the database, and check again.
+	root, nodes, _ := trie.Commit(false)
+	db.Update(NewWithNodeSet(nodes))
+	trie, _ = New(TrieID(root), db, logger)
+	verify(trie)
+}
+
+// TestDelete mixes insertions and deletions (an empty value triggers a
+// delete) and checks the final root hash against a known vector.
+func TestDelete(t *testing.T) {
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	entries := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"ether", ""},
+		{"dog", "puppy"},
+		{"shaman", ""},
+	}
+	for _, entry := range entries {
+		if entry.v == "" {
+			deleteString(tr, entry.k)
+			continue
+		}
+		updateString(tr, entry.k, entry.v)
+	}
+
+	exp := types.StringToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+	if hash := tr.Hash(); hash != exp {
+		t.Errorf("expected %x got %x", exp, hash)
+	}
+}
+
+// TestEmptyValues feeds empty values through Update (instead of calling
+// Delete) and expects the exact same root as TestDelete.
+func TestEmptyValues(t *testing.T) {
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+
+	for _, entry := range []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"ether", ""},
+		{"dog", "puppy"},
+		{"shaman", ""},
+	} {
+		updateString(tr, entry.k, entry.v)
+	}
+
+	exp := types.StringToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
+	if hash := tr.Hash(); hash != exp {
+		t.Errorf("expected %x got %x", exp, hash)
+	}
+}
+
+// TestReplication commits a trie, recreates it from the database, and
+// checks that lookups, recommits and no-op updates all reproduce the
+// same root hash.
+func TestReplication(t *testing.T) {
+	logger := hclog.NewNullLogger()
+	triedb := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+	trie := NewEmpty(triedb)
+	vals := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		{"shaman", "horse"},
+		{"doge", "coin"},
+		{"dog", "puppy"},
+		{"somethingveryoddindeedthis is", "myothernodedata"},
+	}
+	for _, val := range vals {
+		updateString(trie, val.k, val.v)
+	}
+	exp, nodes, err := trie.Commit(false)
+	if err != nil {
+		t.Fatalf("commit error: %v", err)
+	}
+	triedb.Update(NewWithNodeSet(nodes))
+
+	// create a new trie on top of the database and check that lookups work.
+	trie2, err := New(TrieID(exp), triedb, logger)
+	if err != nil {
+		t.Fatalf("can't recreate trie at %x: %v", exp, err)
+	}
+	for _, kv := range vals {
+		if string(getString(trie2, kv.k)) != kv.v {
+			t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
+		}
+	}
+	// recommitting the unchanged replica must yield the same root.
+	hash, nodes, err := trie2.Commit(false)
+	if err != nil {
+		t.Fatalf("commit error: %v", err)
+	}
+	if hash != exp {
+		t.Errorf("root failure. expected %x got %x", exp, hash)
+	}
+
+	// recreate the trie after commit
+	if nodes != nil {
+		triedb.Update(NewWithNodeSet(nodes))
+	}
+	trie2, err = New(TrieID(hash), triedb, logger)
+	if err != nil {
+		t.Fatalf("can't recreate trie at %x: %v", exp, err)
+	}
+	// perform some insertions on the new trie. These re-insert existing
+	// key/value pairs, so the root must stay unchanged.
+	vals2 := []struct{ k, v string }{
+		{"do", "verb"},
+		{"ether", "wookiedoo"},
+		{"horse", "stallion"},
+		// {"shaman", "horse"},
+		// {"doge", "coin"},
+		// {"ether", ""},
+		// {"dog", "puppy"},
+		// {"somethingveryoddindeedthis is", "myothernodedata"},
+		// {"shaman", ""},
+	}
+	for _, val := range vals2 {
+		updateString(trie2, val.k, val.v)
+	}
+	if hash := trie2.Hash(); hash != exp {
+		t.Errorf("root failure. expected %x got %x", exp, hash)
+	}
+}
+
+// TestLargeValue makes sure hashing copes with a mix of short (<32 byte)
+// and exactly-32-byte values.
+func TestLargeValue(t *testing.T) {
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	tr.Update([]byte("key1"), []byte{99, 99, 99, 99})
+	tr.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
+	tr.Hash()
+}
+
+// TestRandomCases tests som cases that were found via random fuzzing
+func TestRandomCases(t *testing.T) {
+ var rt = []randTestStep{
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 0
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 1
+ {op: 0, key: types.StringToBytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: types.StringToBytes("0000000000000002")}, // step 2
+ {op: 2, key: types.StringToBytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: types.StringToBytes("")}, // step 3
+ {op: 3, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 4
+ {op: 3, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 5
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 6
+ {op: 3, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 7
+ {op: 0, key: types.StringToBytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: types.StringToBytes("0000000000000008")}, // step 8
+ {op: 0, key: types.StringToBytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: types.StringToBytes("0000000000000009")}, // step 9
+ {op: 2, key: types.StringToBytes("fd"), value: types.StringToBytes("")}, // step 10
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 11
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 12
+ {op: 0, key: types.StringToBytes("fd"), value: types.StringToBytes("000000000000000d")}, // step 13
+ {op: 6, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 14
+ {op: 1, key: types.StringToBytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: types.StringToBytes("")}, // step 15
+ {op: 3, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 16
+ {op: 0, key: types.StringToBytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: types.StringToBytes("0000000000000011")}, // step 17
+ {op: 5, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 18
+ {op: 3, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 19
+ {op: 0, key: types.StringToBytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: types.StringToBytes("0000000000000014")}, // step 20
+ {op: 0, key: types.StringToBytes("d51b182b95d677e5f1c82508c0228de96b73092d78ce78b2230cd948674f66fd1483bd"), value: types.StringToBytes("0000000000000015")}, // step 21
+ {op: 0, key: types.StringToBytes("c2a38512b83107d665c65235b0250002882ac2022eb00711552354832c5f1d030d0e408e"), value: types.StringToBytes("0000000000000016")}, // step 22
+ {op: 5, key: types.StringToBytes(""), value: types.StringToBytes("")}, // step 23
+ {op: 1, key: types.StringToBytes("980c393656413a15c8da01978ed9f89feb80b502f58f2d640e3a2f5f7a99a7018f1b573befd92053ac6f78fca4a87268"), value: types.StringToBytes("")}, // step 24
+ {op: 1, key: types.StringToBytes("fd"), value: types.StringToBytes("")}, // step 25
+ }
+ runRandTest(rt)
+}
+
+// randTest performs random trie operations.
+// Instances of this test are created by Generate.
+type randTest []randTestStep
+
+// randTestStep is a single scripted operation within a randTest run.
+type randTestStep struct {
+	op    int
+	key   []byte // for opUpdate, opDelete, opGet
+	value []byte // for opUpdate
+	err   error  // for debugging
+}
+
+// Operation codes used by randTestStep.op; Generate draws uniformly
+// from [0, opMax).
+const (
+	opUpdate = iota
+	opDelete
+	opGet
+	opHash
+	opCommit
+	opItercheckhash
+	opNodeDiff
+	opProve
+	opMax // boundary value, not an actual op
+)
+
+// Generate implements testing/quick.Generator, producing a random script
+// of trie operations. Keys are mostly (90%) reused so that updates,
+// deletes and gets collide with earlier entries.
+func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
+	var allKeys [][]byte
+	genKey := func() []byte {
+		if len(allKeys) < 2 || r.Intn(100) < 10 {
+			// new key
+			key := make([]byte, r.Intn(50))
+			r.Read(key)
+			allKeys = append(allKeys, key)
+			return key
+		}
+		// use existing key
+		return allKeys[r.Intn(len(allKeys))]
+	}
+
+	var steps randTest
+	for i := 0; i < size; i++ {
+		step := randTestStep{op: r.Intn(opMax)}
+		switch step.op {
+		case opUpdate:
+			step.key = genKey()
+			step.value = make([]byte, 8)
+			// Encode the step index so each update writes a unique value.
+			binary.BigEndian.PutUint64(step.value, uint64(i))
+		case opGet, opDelete, opProve:
+			step.key = genKey()
+		}
+		steps = append(steps, step)
+	}
+	return reflect.ValueOf(steps)
+}
+
+// runRandTest executes a scripted sequence of trie operations while
+// mirroring the expected content in a plain map, recording the first
+// failing step's error and returning false on any mismatch.
+func runRandTest(rt randTest) bool {
+	var (
+		logger   = hclog.NewNullLogger()
+		triedb   = NewDatabase(rawdb.NewMemoryDatabase(), logger)
+		tr       = NewEmpty(triedb)
+		values   = make(map[string]string) // tracks content of the trie
+		origTrie = NewEmpty(triedb)
+	)
+	tr.tracer = newTracer()
+
+	for i, step := range rt {
+		// fmt.Printf("{op: %d, key: types.StringToBytes(\"%x\"), value: types.StringToBytes(\"%x\")}, // step %d\n",
+		// 	step.op, step.key, step.value, i)
+
+		switch step.op {
+		case opUpdate:
+			tr.Update(step.key, step.value)
+			values[string(step.key)] = string(step.value)
+		case opDelete:
+			tr.Delete(step.key)
+			delete(values, string(step.key))
+		case opGet:
+			v := tr.Get(step.key)
+			want := values[string(step.key)]
+			if string(v) != want {
+				rt[i].err = fmt.Errorf("mismatch for key %#x, got %#x want %#x", step.key, v, want)
+			}
+		case opProve:
+			hash := tr.Hash()
+			if hash == types.EmptyRootHash {
+				continue
+			}
+			proofDB := rawdb.NewMemoryDatabase()
+			err := tr.Prove(step.key, 0, proofDB)
+			if err != nil {
+				rt[i].err = fmt.Errorf("failed for proving key %#x, %v", step.key, err)
+			}
+			_, err = VerifyProof(hash, step.key, proofDB)
+			if err != nil {
+				rt[i].err = fmt.Errorf("failed for verifying key %#x, %v", step.key, err)
+			}
+		case opHash:
+			tr.Hash()
+		case opCommit:
+			root, nodes, err := tr.Commit(true)
+			if err != nil {
+				rt[i].err = err
+				return false
+			}
+			// Validate the returned nodeset: each reported update/delete
+			// must match the node's previous value in the original trie.
+			if nodes != nil {
+				for path, node := range nodes.updates.nodes {
+					blob, _, _ := origTrie.TryGetNode(hexToCompact([]byte(path)))
+					got := node.prev
+					if !bytes.Equal(blob, got) {
+						rt[i].err = fmt.Errorf("prevalue mismatch for 0x%x, got 0x%x want 0x%x", path, got, blob)
+						panic(rt[i].err)
+					}
+				}
+				for path, prev := range nodes.deletes {
+					blob, _, _ := origTrie.TryGetNode(hexToCompact([]byte(path)))
+					if !bytes.Equal(blob, prev) {
+						rt[i].err = fmt.Errorf("prevalue mismatch for 0x%x, got 0x%x want 0x%x", path, prev, blob)
+						return false
+					}
+				}
+			}
+			if nodes != nil {
+				triedb.Update(NewWithNodeSet(nodes))
+			}
+			newtr, err := New(TrieID(root), triedb, logger)
+			if err != nil {
+				rt[i].err = err
+				return false
+			}
+			tr = newtr
+
+			// Enable node tracing. Resolve the root node again explicitly
+			// since it's not captured at the beginning.
+			tr.tracer = newTracer()
+			tr.resolveAndTrack(root.Bytes(), nil)
+
+			origTrie = tr.Copy()
+		case opItercheckhash:
+			// Rebuild the trie from an iterator walk and cross-check roots.
+			checktr := NewEmpty(triedb)
+			it := NewIterator(tr.NodeIterator(nil))
+			for it.Next() {
+				checktr.Update(it.Key, it.Value)
+			}
+			if tr.Hash() != checktr.Hash() {
+				rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
+			}
+		case opNodeDiff:
+			// Compare the tracer's insert/delete lists against the node
+			// sets observed by walking the original and current tries.
+			var (
+				inserted = tr.tracer.insertList()
+				deleted  = tr.tracer.deleteList()
+				origIter = origTrie.NodeIterator(nil)
+				curIter  = tr.NodeIterator(nil)
+				origSeen = make(map[string]struct{})
+				curSeen  = make(map[string]struct{})
+			)
+			for origIter.Next(true) {
+				if origIter.Leaf() {
+					continue
+				}
+				origSeen[string(origIter.Path())] = struct{}{}
+			}
+			for curIter.Next(true) {
+				if curIter.Leaf() {
+					continue
+				}
+				curSeen[string(curIter.Path())] = struct{}{}
+			}
+			var (
+				insertExp = make(map[string]struct{})
+				deleteExp = make(map[string]struct{})
+			)
+			for path := range curSeen {
+				_, present := origSeen[path]
+				if !present {
+					insertExp[path] = struct{}{}
+				}
+			}
+			for path := range origSeen {
+				_, present := curSeen[path]
+				if !present {
+					deleteExp[path] = struct{}{}
+				}
+			}
+			if len(insertExp) != len(inserted) {
+				rt[i].err = fmt.Errorf("insert set mismatch")
+			}
+			if len(deleteExp) != len(deleted) {
+				rt[i].err = fmt.Errorf("delete set mismatch")
+			}
+			for _, insert := range inserted {
+				if _, present := insertExp[string(insert)]; !present {
+					rt[i].err = fmt.Errorf("missing inserted node")
+				}
+			}
+			for _, del := range deleted {
+				if _, present := deleteExp[string(del)]; !present {
+					rt[i].err = fmt.Errorf("missing deleted node")
+				}
+			}
+		}
+		// Abort the test on error.
+		if rt[i].err != nil {
+			return false
+		}
+	}
+	return true
+}
+
+// TestRandom drives runRandTest through testing/quick with generated
+// operation scripts, dumping the failing script on error.
+func TestRandom(t *testing.T) {
+	err := quick.Check(runRandTest, nil)
+	if err == nil {
+		return
+	}
+	if cerr, ok := err.(*quick.CheckError); ok {
+		t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
+	}
+	t.Fatal(err)
+}
+
+// Micro-benchmark entry points for point lookups and for updates keyed
+// in big- and little-endian order.
+func BenchmarkGet(b *testing.B)      { benchGet(b) }
+func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
+func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
+
+// benchElemCount is the number of entries preloaded before benchGet runs.
+const benchElemCount = 20000
+
+// benchGet measures point lookups against a trie preloaded with
+// benchElemCount entries, probing the middle element.
+func benchGet(b *testing.B) {
+	b.Helper()
+
+	triedb := NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger())
+	tr := NewEmpty(triedb)
+	key := make([]byte, 32)
+	for i := 0; i < benchElemCount; i++ {
+		binary.LittleEndian.PutUint64(key, uint64(i))
+		tr.Update(key, key)
+	}
+	binary.LittleEndian.PutUint64(key, benchElemCount/2)
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		tr.Get(key)
+	}
+	b.StopTimer()
+}
+
+// benchUpdate measures inserts keyed by the loop counter encoded in the
+// given byte order, returning the populated trie.
+func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
+	b.Helper()
+
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	key := make([]byte, 32)
+	b.ReportAllocs()
+	for i := 0; i < b.N; i++ {
+		e.PutUint64(key, uint64(i))
+		tr.Update(key, key)
+	}
+	return tr
+}
+
+// Benchmarks the trie hashing. Since the trie caches the result of any operation,
+// we cannot use b.N as the number of hashing rounds, since all rounds apart from
+// the first one will be NOOP. As such, we'll use b.N as the number of account to
+// insert into the trie before measuring the hashing.
+// BenchmarkHash-6   	  288680	      4561 ns/op	     682 B/op	       9 allocs/op
+// BenchmarkHash-6   	  275095	      4800 ns/op	     685 B/op	       9 allocs/op
+// pure hasher:
+// BenchmarkHash-6   	  319362	      4230 ns/op	     675 B/op	       9 allocs/op
+// BenchmarkHash-6   	  257460	      4674 ns/op	     689 B/op	       9 allocs/op
+// With hashing in-between and pure hasher:
+// BenchmarkHash-6   	  225417	      7150 ns/op	     982 B/op	      12 allocs/op
+// BenchmarkHash-6   	  220378	      6197 ns/op	     983 B/op	      12 allocs/op
+// same with old hasher
+// BenchmarkHash-6   	  229758	      6437 ns/op	     981 B/op	      12 allocs/op
+// BenchmarkHash-6   	  212610	      7137 ns/op	     986 B/op	      12 allocs/op
+func BenchmarkHash(b *testing.B) {
+	// Create a realistic account trie to hash. We're first adding and hashing N
+	// entries, then adding N more.
+	addresses, accounts := makeAccounts(2 * b.N)
+	// Insert the accounts into the trie and hash it
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	i := 0
+	for ; i < len(addresses)/2; i++ {
+		trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+	}
+	// Hash the first half so the timed Hash below only reflects the
+	// incremental work for the second half.
+	trie.Hash()
+	for ; i < len(addresses); i++ {
+		trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+	}
+	b.ResetTimer()
+	b.ReportAllocs()
+	//trie.hashRoot(nil, nil)
+	trie.Hash()
+}
+
+// BenchmarkCommitAfterHash measures Commit cost (with and without leaf
+// collection) once the trie has already been hashed. Since the trie
+// caches results, b.N is used as the number of accounts to insert
+// rather than the number of commit rounds.
+func BenchmarkCommitAfterHash(b *testing.B) {
+	for _, bench := range []struct {
+		name        string
+		collectLeaf bool
+	}{
+		{"no-onleaf", false},
+		{"with-onleaf", true},
+	} {
+		b.Run(bench.name, func(b *testing.B) {
+			benchmarkCommitAfterHash(b, bench.collectLeaf)
+		})
+	}
+}
+
+// benchmarkCommitAfterHash fills a trie with b.N deterministic accounts,
+// hashes it, and then times only the Commit step.
+func benchmarkCommitAfterHash(b *testing.B, collectLeaf bool) {
+	b.Helper()
+
+	// Make the random benchmark deterministic
+	addresses, accounts := makeAccounts(b.N)
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	for i := range addresses {
+		tr.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+	}
+	// Hashing up front means the timed section measures commit only.
+	tr.Hash()
+	b.ResetTimer()
+	b.ReportAllocs()
+	tr.Commit(collectLeaf)
+}
+
+// TestTinyTrie checks the root hashes of a tiny account trie against
+// known vectors after each insert, then verifies an iterator-rebuilt
+// copy hashes identically.
+func TestTinyTrie(t *testing.T) {
+	// Create a realistic account trie to hash
+	_, accounts := makeAccounts(5)
+	logger := hclog.NewNullLogger()
+	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), logger))
+	trie.Update(types.StringToBytes("0000000000000000000000000000000000000000000000000000000000001337"), accounts[3])
+	if exp, root := types.StringToHash("8c6a85a4d9fda98feff88450299e574e5378e32391f75a055d470ac0653f1005"), trie.Hash(); exp != root {
+		t.Errorf("1: got %x, exp %x", root, exp)
+	}
+	trie.Update(types.StringToBytes("0000000000000000000000000000000000000000000000000000000000001338"), accounts[4])
+	if exp, root := types.StringToHash("ec63b967e98a5720e7f720482151963982890d82c9093c0d486b7eb8883a66b1"), trie.Hash(); exp != root {
+		t.Errorf("2: got %x, exp %x", root, exp)
+	}
+	trie.Update(types.StringToBytes("0000000000000000000000000000000000000000000000000000000000001339"), accounts[4])
+	if exp, root := types.StringToHash("0608c1d1dc3905fa22204c7a0e43644831c3b6d3def0f274be623a948197e64a"), trie.Hash(); exp != root {
+		t.Errorf("3: got %x, exp %x", root, exp)
+	}
+	// Rebuild via a node iterator and cross-check the roots.
+	checktr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), logger))
+	it := NewIterator(trie.NodeIterator(nil))
+	for it.Next() {
+		checktr.Update(it.Key, it.Value)
+	}
+	if troot, itroot := trie.Hash(), checktr.Hash(); troot != itroot {
+		t.Fatalf("hash mismatch in opItercheckhash, trie: %x, check: %x", troot, itroot)
+	}
+}
+
+// TestCommitAfterHash checks that committing an already-hashed trie
+// neither changes the root nor breaks a repeated commit.
+func TestCommitAfterHash(t *testing.T) {
+	// Create a realistic account trie to hash
+	addresses, accounts := makeAccounts(1000)
+	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+	for i := range addresses {
+		tr.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+	}
+	// Insert the accounts into the trie and hash it
+	tr.Hash()
+	tr.Commit(false)
+
+	exp := types.StringToHash("72f9d3f3fe1e1dd7b8936442e7642aef76371472d94319900790053c493f3fe6")
+	if root := tr.Hash(); exp != root {
+		t.Errorf("got %x, exp %x", root, exp)
+	}
+	if root, _, _ := tr.Commit(false); exp != root {
+		t.Errorf("got %x, exp %x", root, exp)
+	}
+}
+
+// makeAccounts returns `size` deterministic pseudo-random addresses
+// together with matching rlp-encoded account payloads. The fixed seed
+// makes tests and benchmarks reproducible across runs.
+func makeAccounts(size int) (addresses [][20]byte, accounts [][]byte) {
+	// Make the random benchmark deterministic
+	random := rand.New(rand.NewSource(0))
+	// Create a realistic account trie to hash
+	addresses = make([][20]byte, size)
+	for i := 0; i < len(addresses); i++ {
+		data := make([]byte, 20)
+		random.Read(data)
+		copy(addresses[i][:], data)
+	}
+	accounts = make([][]byte, len(addresses))
+	for i := 0; i < len(accounts); i++ {
+		var (
+			nonce = uint64(random.Int63())
+			root  = types.EmptyRootHash
+			code  = crypto.Keccak256(nil)
+		)
+		// The big.Rand function is not deterministic with regards to 64 vs 32 bit systems,
+		// and will consume different amount of data from the rand source.
+		//balance = new(big.Int).Rand(random, new(big.Int).Exp(types.Big2, types.Big256, nil))
+		// Therefore, we instead just read via byte buffer
+		numBytes := random.Uint32() % 33 // [0, 32] bytes
+		balanceBytes := make([]byte, numBytes)
+		random.Read(balanceBytes)
+		balance := new(big.Int).SetBytes(balanceBytes)
+		// account
+		acc := &stypes.Account{Nonce: nonce, Balance: balance, StorageRoot: root, CodeHash: code}
+		// marshal via a pooled arena to avoid per-account allocations
+		ar := accountArenaPool.Get()
+		vv := acc.MarshalWith(ar)
+		accountArenaPool.Put(ar)
+		data := vv.MarshalTo(nil)
+		// set it
+		accounts[i] = data
+	}
+	return addresses, accounts
+}
+
+// spongeDB is a dummy db backend which accumulates writes in a sponge
+type spongeDB struct {
+	sponge  hash.Hash
+	id      string
+	journal []string
+}
+
+// Set records a human-readable journal entry and absorbs key and value
+// into the sponge. Every other mutation/query is either unsupported
+// (panics) or a no-op.
+func (s *spongeDB) Set(key []byte, value []byte) error {
+	valbrief := value
+	if len(valbrief) > 8 {
+		valbrief = valbrief[:8]
+	}
+	s.journal = append(s.journal, fmt.Sprintf("%v: Set([%x...], [%d bytes] %x...)\n", s.id, key[:8], len(value), valbrief))
+	s.sponge.Write(key)
+	s.sponge.Write(value)
+
+	return nil
+}
+
+func (s *spongeDB) Get(key []byte) ([]byte, bool, error) {
+	return nil, false, errors.New("no such elem")
+}
+
+func (s *spongeDB) Has(key []byte) (bool, error)                           { panic("implement me") }
+func (s *spongeDB) Delete(key []byte) error                                { panic("implement me") }
+func (s *spongeDB) NewBatch() kvdb.Batch                                   { return &spongeBatch{s} }
+func (s *spongeDB) NewSnapshot() (kvdb.Snapshot, error)                    { panic("implement me") }
+func (s *spongeDB) Stat(property string) (string, error)                   { panic("implement me") }
+func (s *spongeDB) Compact(start []byte, limit []byte) error               { panic("implement me") }
+func (s *spongeDB) Close() error                                           { return nil }
+func (s *spongeDB) NewIterator(prefix []byte, start []byte) kvdb.Iterator  { panic("implement me") }
+
+// spongeBatch is a dummy batch which immediately writes to the underlying spongedb
+type spongeBatch struct {
+	db *spongeDB
+}
+
+// Set forwards straight to the sponge database; batching is not emulated.
+func (b *spongeBatch) Set(key, value []byte) error {
+	b.db.Set(key, value)
+
+	return nil
+}
+
+func (b *spongeBatch) Delete(key []byte) error      { panic("implement me") }
+func (b *spongeBatch) ValueSize() int               { return 100 }
+func (b *spongeBatch) Write() error                 { return nil }
+func (b *spongeBatch) Reset()                       {}
+func (b *spongeBatch) Replay(w kvdb.KVWriter) error { return nil }
+
+// TestCommitSequence tests that the trie.Commit operation writes the elements of the trie
+// in the expected order, and calls the callbacks in the expected order.
+// The test data was based on the 'master' code, and is basically random. It can be used
+// to check whether changes to the trie modifies the write order or data in any way.
+func TestCommitSequence(t *testing.T) {
+	for i, tc := range []struct {
+		count              int
+		expWriteSeqHash    []byte
+		expCallbackSeqHash []byte
+	}{
+		{20, types.StringToBytes("873c78df73d60e59d4a2bcf3716e8bfe14554549fea2fc147cb54129382a8066"),
+			types.StringToBytes("ff00f91ac05df53b82d7f178d77ada54fd0dca64526f537034a5dbe41b17df2a")},
+		{200, types.StringToBytes("ba03d891bb15408c940eea5ee3d54d419595102648d02774a0268d892add9c8e"),
+			types.StringToBytes("f3cd509064c8d319bbdd1c68f511850a902ad275e6ed5bea11547e23d492a926")},
+		{2000, types.StringToBytes("f7a184f20df01c94f09537401d11e68d97ad0c00115233107f51b9c287ce60c7"),
+			types.StringToBytes("ff795ea898ba1e4cfed4a33b4cf5535a347a02cf931f88d88719faf810f9a1c9")},
+	} {
+		addresses, accounts := makeAccounts(tc.count)
+		// This spongeDb is used to check the sequence of disk-db-writes
+		s := &spongeDB{sponge: sha3.NewLegacyKeccak256()}
+		db := NewDatabase(rawdb.NewDatabase(s), hclog.NewNullLogger())
+		trie := NewEmpty(db)
+		// Another sponge is used to check the callback-sequence
+		callbackSponge := sha3.NewLegacyKeccak256()
+		// Fill the trie with elements
+		for i := 0; i < tc.count; i++ {
+			trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+		}
+		// Flush trie -> database
+		root, nodes, _ := trie.Commit(false)
+		db.Update(NewWithNodeSet(nodes))
+		// Flush memdb -> disk (sponge)
+		db.Commit(root, false, func(c types.Hash) {
+			// And spongify the callback-order
+			callbackSponge.Write(c[:])
+		})
+		// Both the raw write order and the callback order must match the
+		// recorded golden digests.
+		if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) {
+			t.Errorf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp)
+		}
+		if got, exp := callbackSponge.Sum(nil), tc.expCallbackSeqHash; !bytes.Equal(got, exp) {
+			t.Errorf("test %d, call back sequence wrong:\ngot: %x exp %x\n", i, got, exp)
+		}
+	}
+}
+
+// TestCommitSequenceRandomBlobs is identical to TestCommitSequence
+// but uses random blobs instead of 'accounts'
+func TestCommitSequenceRandomBlobs(t *testing.T) {
+	for i, tc := range []struct {
+		count              int
+		expWriteSeqHash    []byte
+		expCallbackSeqHash []byte
+	}{
+		{20, types.StringToBytes("8e4a01548551d139fa9e833ebc4e66fc1ba40a4b9b7259d80db32cff7b64ebbc"),
+			types.StringToBytes("450238d73bc36dc6cc6f926987e5428535e64be403877c4560e238a52749ba24")},
+		{200, types.StringToBytes("6869b4e7b95f3097a19ddb30ff735f922b915314047e041614df06958fc50554"),
+			types.StringToBytes("0ace0b03d6cb8c0b82f6289ef5b1a1838306b455a62dafc63cada8e2924f2550")},
+		{2000, types.StringToBytes("444200e6f4e2df49f77752f629a96ccf7445d4698c164f962bbd85a0526ef424"),
+			types.StringToBytes("117d30dafaa62a1eed498c3dfd70982b377ba2b46dd3e725ed6120c80829e518")},
+	} {
+		// Seed per test case so each blob set is deterministic.
+		prng := rand.New(rand.NewSource(int64(i)))
+		// This spongeDb is used to check the sequence of disk-db-writes
+		s := &spongeDB{sponge: sha3.NewLegacyKeccak256()}
+		db := NewDatabase(rawdb.NewDatabase(s), hclog.NewNullLogger())
+		trie := NewEmpty(db)
+		// Another sponge is used to check the callback-sequence
+		callbackSponge := sha3.NewLegacyKeccak256()
+		// Fill the trie with elements
+		for i := 0; i < tc.count; i++ {
+			key := make([]byte, 32)
+			var val []byte
+			// 50% short elements, 50% large elements
+			if prng.Intn(2) == 0 {
+				val = make([]byte, 1+prng.Intn(32))
+			} else {
+				val = make([]byte, 1+prng.Intn(4096))
+			}
+			prng.Read(key)
+			prng.Read(val)
+			trie.Update(key, val)
+		}
+		// Flush trie -> database
+		root, nodes, _ := trie.Commit(false)
+		db.Update(NewWithNodeSet(nodes))
+		// Flush memdb -> disk (sponge)
+		db.Commit(root, false, func(c types.Hash) {
+			// And spongify the callback-order
+			callbackSponge.Write(c[:])
+		})
+		if got, exp := s.sponge.Sum(nil), tc.expWriteSeqHash; !bytes.Equal(got, exp) {
+			t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", i, got, exp)
+		}
+		if got, exp := callbackSponge.Sum(nil), tc.expCallbackSeqHash; !bytes.Equal(got, exp) {
+			t.Fatalf("test %d, call back sequence wrong:\ngot: %x exp %x\n", i, got, exp)
+		}
+	}
+}
+
+// TestCommitSequenceStackTrie cross-checks the regular trie against the
+// stack trie: for growing key counts, both must produce the same root
+// and the exact same sequence of disk writes.
+func TestCommitSequenceStackTrie(t *testing.T) {
+	for count := 1; count < 200; count++ {
+		prng := rand.New(rand.NewSource(int64(count)))
+		// This spongeDb is used to check the sequence of disk-db-writes
+		s := &spongeDB{sponge: sha3.NewLegacyKeccak256(), id: "a"}
+		db := NewDatabase(rawdb.NewDatabase(s), hclog.NewNullLogger())
+		trie := NewEmpty(db)
+		// Another sponge is used for the stacktrie commits
+		stackTrieSponge := &spongeDB{sponge: sha3.NewLegacyKeccak256(), id: "b"}
+		stTrie := NewStackTrie(func(owner types.Hash, path []byte, hash types.Hash, blob []byte) {
+			db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob)
+		})
+		// Fill the trie with elements
+		for i := 0; i < count; i++ {
+			// For the stack trie, we need to do inserts in proper order
+			key := make([]byte, 32)
+			binary.BigEndian.PutUint64(key, uint64(i))
+			var val []byte
+			// 50% short elements, 50% large elements
+			if prng.Intn(2) == 0 {
+				val = make([]byte, 1+prng.Intn(32))
+			} else {
+				val = make([]byte, 1+prng.Intn(1024))
+			}
+			prng.Read(val)
+			trie.TryUpdate(key, val)
+			stTrie.TryUpdate(key, val)
+		}
+		// Flush trie -> database
+		root, nodes, _ := trie.Commit(false)
+		// Flush memdb -> disk (sponge)
+		db.Update(NewWithNodeSet(nodes))
+		db.Commit(root, false, nil)
+		// And flush stacktrie -> disk
+		stRoot, err := stTrie.Commit()
+		if err != nil {
+			t.Fatalf("Failed to commit stack trie %v", err)
+		}
+		if stRoot != root {
+			t.Fatalf("root wrong, got %x exp %x", stRoot, root)
+		}
+		// On divergence, dump both write journals to aid debugging.
+		if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) {
+			// Show the journal
+			t.Logf("Expected:")
+			for i, v := range s.journal {
+				t.Logf("op %d: %v", i, v)
+			}
+			t.Logf("Stacktrie:")
+			for i, v := range stackTrieSponge.journal {
+				t.Logf("op %d: %v", i, v)
+			}
+			t.Fatalf("test %d, disk write sequence wrong:\ngot %x exp %x\n", count, got, exp)
+		}
+	}
+}
+
+// TestCommitSequenceSmallRoot tests that a trie which is essentially only a
+// small (<32 byte) shortnode with an included value is properly committed to a
+// database.
+// This case might not matter, since in practice, all keys are 32 bytes, which means
+// that even a small trie which contains a leaf will have an extension making it
+// not fit into 32 bytes, rlp-encoded. However, it's still the correct thing to do.
+func TestCommitSequenceSmallRoot(t *testing.T) {
+ s := &spongeDB{sponge: sha3.NewLegacyKeccak256(), id: "a"}
+ db := NewDatabase(rawdb.NewDatabase(s), hclog.NewNullLogger())
+ trie := NewEmpty(db)
+ // Another sponge is used for the stacktrie commits
+ stackTrieSponge := &spongeDB{sponge: sha3.NewLegacyKeccak256(), id: "b"}
+ stTrie := NewStackTrie(func(owner types.Hash, path []byte, hash types.Hash, blob []byte) {
+ db.Scheme().WriteTrieNode(stackTrieSponge, owner, path, hash, blob)
+ })
+ // Add a single small-element to the trie(s)
+ key := make([]byte, 5)
+ key[0] = 1
+ trie.TryUpdate(key, []byte{0x1})
+ stTrie.TryUpdate(key, []byte{0x1})
+ // Flush trie -> database
+ root, nodes, _ := trie.Commit(false)
+ // Flush memdb -> disk (sponge)
+ db.Update(NewWithNodeSet(nodes))
+ db.Commit(root, false, nil)
+ // And flush stacktrie -> disk
+ stRoot, err := stTrie.Commit()
+ if err != nil {
+ t.Fatalf("Failed to commit stack trie %v", err)
+ }
+ if stRoot != root {
+ t.Fatalf("root wrong, got %x exp %x", stRoot, root)
+ }
+
+ t.Logf("root: %x\n", stRoot)
+ if got, exp := stackTrieSponge.sponge.Sum(nil), s.sponge.Sum(nil); !bytes.Equal(got, exp) {
+ t.Fatalf("test, disk write sequence wrong:\ngot %x exp %x\n", got, exp)
+ }
+}
+
+// BenchmarkHashFixedSize benchmarks the Hash of a fixed number of updates to a trie.
+// This benchmark is meant to capture the difference on efficiency of small versus large changes. Typically,
+// storage tries are small (a couple of entries), whereas the full post-block account trie update is large (a couple
+// of thousand entries)
+func BenchmarkHashFixedSize(b *testing.B) {
+ b.Run("10", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(20)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+
+ b.Run("1K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(1000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("10K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(10000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100000)
+ for i := 0; i < b.N; i++ {
+ benchmarkHashFixedSize(b, acc, add)
+ }
+ })
+}
+
+func benchmarkHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byte) {
+ b.Helper()
+
+ b.ReportAllocs()
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ for i := 0; i < len(addresses); i++ {
+ trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Insert the accounts into the trie and hash it
+ b.StartTimer()
+ trie.Hash()
+ b.StopTimer()
+}
+
+func BenchmarkCommitAfterHashFixedSize(b *testing.B) {
+ b.Run("10", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(20)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+
+ b.Run("1K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(1000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("10K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(10000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100000)
+ for i := 0; i < b.N; i++ {
+ benchmarkCommitAfterHashFixedSize(b, acc, add)
+ }
+ })
+}
+
+func benchmarkCommitAfterHashFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byte) {
+ b.Helper()
+
+ b.ReportAllocs()
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ for i := 0; i < len(addresses); i++ {
+ trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ // Insert the accounts into the trie and hash it
+ trie.Hash()
+ b.StartTimer()
+ trie.Commit(false)
+ b.StopTimer()
+}
+
+func BenchmarkDerefRootFixedSize(b *testing.B) {
+ b.Run("10", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(20)
+ for i := 0; i < b.N; i++ {
+ benchmarkDerefRootFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100)
+ for i := 0; i < b.N; i++ {
+ benchmarkDerefRootFixedSize(b, acc, add)
+ }
+ })
+
+ b.Run("1K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(1000)
+ for i := 0; i < b.N; i++ {
+ benchmarkDerefRootFixedSize(b, acc, add)
+ }
+ })
+ b.Run("10K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(10000)
+ for i := 0; i < b.N; i++ {
+ benchmarkDerefRootFixedSize(b, acc, add)
+ }
+ })
+ b.Run("100K", func(b *testing.B) {
+ b.StopTimer()
+ acc, add := makeAccounts(100000)
+ for i := 0; i < b.N; i++ {
+ benchmarkDerefRootFixedSize(b, acc, add)
+ }
+ })
+}
+
+func benchmarkDerefRootFixedSize(b *testing.B, addresses [][20]byte, accounts [][]byte) {
+ b.Helper()
+
+ b.ReportAllocs()
+ triedb := NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger())
+ trie := NewEmpty(triedb)
+ for i := 0; i < len(addresses); i++ {
+ trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
+ }
+ h := trie.Hash()
+ _, nodes, _ := trie.Commit(false)
+ triedb.Update(NewWithNodeSet(nodes))
+ b.StartTimer()
+ triedb.Dereference(h)
+ b.StopTimer()
+}
+
+func getString(trie *Trie, k string) []byte {
+ return trie.Get([]byte(k))
+}
+
+func updateString(trie *Trie, k, v string) {
+ trie.Update([]byte(k), []byte(v))
+}
+
+func deleteString(trie *Trie, k string) {
+ trie.Delete([]byte(k))
+}
+
+func TestDecodeNode(t *testing.T) {
+ t.Parallel()
+ var (
+ hash = make([]byte, 20)
+ elems = make([]byte, 20)
+ )
+ for i := 0; i < 5000000; i++ {
+ rand.Read(hash)
+ rand.Read(elems)
+ decodeNode(hash, elems)
+ }
+}
diff --git a/trie/util_test.go b/trie/util_test.go
new file mode 100644
index 0000000000..d0c249e595
--- /dev/null
+++ b/trie/util_test.go
@@ -0,0 +1,353 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/dogechain-lab/dogechain/helper/rawdb"
+ "github.com/dogechain-lab/dogechain/types"
+ "github.com/hashicorp/go-hclog"
+)
+
+// Tests if the trie diffs are tracked correctly.
+func TestTrieTracer(t *testing.T) {
+ logger := hclog.NewNullLogger()
+ db := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+ trie := NewEmpty(db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ trie.Hash()
+
+ seen := make(map[string]struct{})
+ it := trie.NodeIterator(nil)
+ for it.Next(true) {
+ if it.Leaf() {
+ continue
+ }
+ seen[string(it.Path())] = struct{}{}
+ }
+ inserted := trie.tracer.insertList()
+ if len(inserted) != len(seen) {
+ t.Fatalf("Unexpected inserted node tracked want %d got %d", len(seen), len(inserted))
+ }
+ for _, k := range inserted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+ deleted := trie.tracer.deleteList()
+ if len(deleted) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(deleted))
+ }
+
+ // Commit the changes and re-create with new root
+ root, nodes, _ := trie.Commit(false)
+ if err := db.Update(NewWithNodeSet(nodes)); err != nil {
+ t.Fatal(err)
+ }
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+
+ // Delete all the elements, check deletion set
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ trie.Hash()
+
+ inserted = trie.tracer.insertList()
+ if len(inserted) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(inserted))
+ }
+ deleted = trie.tracer.deleteList()
+ if len(deleted) != len(seen) {
+ t.Fatalf("Unexpected deleted node tracked want %d got %d", len(seen), len(deleted))
+ }
+ for _, k := range deleted {
+ _, ok := seen[string(k)]
+ if !ok {
+ t.Fatalf("Unexpected inserted node")
+ }
+ }
+}
+
+func TestTrieTracerNoop(t *testing.T) {
+ trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase(), hclog.NewNullLogger()))
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ if len(trie.tracer.insertList()) != 0 {
+ t.Fatalf("Unexpected inserted node tracked %d", len(trie.tracer.insertList()))
+ }
+ if len(trie.tracer.deleteList()) != 0 {
+ t.Fatalf("Unexpected deleted node tracked %d", len(trie.tracer.deleteList()))
+ }
+}
+
+func TestTrieTracePrevValue(t *testing.T) {
+ logger := hclog.NewNullLogger()
+ db := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+ trie := NewEmpty(db)
+ trie.tracer = newTracer()
+
+ paths, blobs := trie.tracer.prevList()
+ if len(paths) != 0 || len(blobs) != 0 {
+ t.Fatalf("Nothing should be tracked")
+ }
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ paths, blobs = trie.tracer.prevList()
+ if len(paths) != 0 || len(blobs) != 0 {
+ t.Fatalf("Nothing should be tracked")
+ }
+
+ // Commit the changes and re-create with new root
+ root, nodes, _ := trie.Commit(false)
+ if err := db.Update(NewWithNodeSet(nodes)); err != nil {
+ t.Fatal(err)
+ }
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+ trie.resolveAndTrack(root.Bytes(), nil)
+
+ // Load all nodes in trie
+ for _, val := range vals {
+ trie.TryGet([]byte(val.k))
+ }
+
+ // Ensure all nodes are tracked by tracer with correct prev-values
+ iter := trie.NodeIterator(nil)
+ seen := make(map[string][]byte)
+ for iter.Next(true) {
+ // Embedded nodes are ignored since they are not present in
+ // database.
+ if iter.Hash() == (types.Hash{}) {
+ continue
+ }
+ seen[string(iter.Path())] = types.CopyBytes(iter.NodeBlob())
+ }
+
+ paths, blobs = trie.tracer.prevList()
+ if len(paths) != len(seen) || len(blobs) != len(seen) {
+ t.Fatalf("Unexpected tracked values")
+ }
+ for i, path := range paths {
+ blob := blobs[i]
+ prev, ok := seen[string(path)]
+ if !ok {
+ t.Fatalf("Missing node %v", path)
+ }
+ if !bytes.Equal(blob, prev) {
+ t.Fatalf("Unexpected value path: %v, want: %v, got: %v", path, prev, blob)
+ }
+ }
+
+ // Re-open the trie and iterate the trie, ensure nothing will be tracked.
+ // Iterator will not link any loaded nodes to trie.
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+
+ iter = trie.NodeIterator(nil)
+ for iter.Next(true) {
+ }
+ paths, blobs = trie.tracer.prevList()
+ if len(paths) != 0 || len(blobs) != 0 {
+ t.Fatalf("Nothing should be tracked")
+ }
+
+ // Re-open the trie and generate proof for entries, ensure nothing will
+ // be tracked. Prover will not link any loaded nodes to trie.
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+ for _, val := range vals {
+ trie.Prove([]byte(val.k), 0, rawdb.NewMemoryDatabase())
+ }
+ paths, blobs = trie.tracer.prevList()
+ if len(paths) != 0 || len(blobs) != 0 {
+ t.Fatalf("Nothing should be tracked")
+ }
+
+ // Delete entries from trie, ensure all previous values are correct.
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+ trie.resolveAndTrack(root.Bytes(), nil)
+
+ for _, val := range vals {
+ trie.TryDelete([]byte(val.k))
+ }
+ paths, blobs = trie.tracer.prevList()
+ if len(paths) != len(seen) || len(blobs) != len(seen) {
+ t.Fatalf("Unexpected tracked values")
+ }
+ for i, path := range paths {
+ blob := blobs[i]
+ prev, ok := seen[string(path)]
+ if !ok {
+ t.Fatalf("Missing node %v", path)
+ }
+ if !bytes.Equal(blob, prev) {
+ t.Fatalf("Unexpected value path: %v, want: %v, got: %v", path, prev, blob)
+ }
+ }
+}
+
+func TestDeleteAll(t *testing.T) {
+ logger := hclog.NewNullLogger()
+ db := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+ trie := NewEmpty(db)
+ trie.tracer = newTracer()
+
+ // Insert a batch of entries, all the nodes should be marked as inserted
+ vals := []struct{ k, v string }{
+ {"do", "verb"},
+ {"ether", "wookiedoo"},
+ {"horse", "stallion"},
+ {"shaman", "horse"},
+ {"doge", "coin"},
+ {"dog", "puppy"},
+ {"somethingveryoddindeedthis is", "myothernodedata"},
+ }
+ for _, val := range vals {
+ trie.Update([]byte(val.k), []byte(val.v))
+ }
+ root, set, err := trie.Commit(false)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := db.Update(NewWithNodeSet(set)); err != nil {
+ t.Fatal(err)
+ }
+ // Delete entries from trie, ensure all values are detected
+ trie, _ = New(TrieID(root), db, logger)
+ trie.tracer = newTracer()
+ trie.resolveAndTrack(root.Bytes(), nil)
+
+ // Iterate all existent nodes
+ var (
+ it = trie.NodeIterator(nil)
+ nodes = make(map[string][]byte)
+ )
+ for it.Next(true) {
+ if it.Hash() != (types.Hash{}) {
+ nodes[string(it.Path())] = types.CopyBytes(it.NodeBlob())
+ }
+ }
+
+ // Perform deletion to purge the entire trie
+ for _, val := range vals {
+ trie.Delete([]byte(val.k))
+ }
+ root, set, err = trie.Commit(false)
+ if err != nil {
+ t.Fatalf("Failed to delete trie %v", err)
+ }
+ if root != types.EmptyRootHash {
+ t.Fatalf("Invalid trie root %v", root)
+ }
+ for path, blob := range set.deletes {
+ prev, ok := nodes[path]
+ if !ok {
+ t.Fatalf("Extra node deleted %v", []byte(path))
+ }
+ if !bytes.Equal(prev, blob) {
+ t.Fatalf("Unexpected previous value %v", []byte(path))
+ }
+ }
+ if len(set.deletes) != len(nodes) {
+ t.Fatalf("Unexpected deletion set")
+ }
+}
+
+// makeTestTrie create a sample test trie to test node-wise reconstruction.
+func makeTestTrie() (*Database, *StateTrie, map[string][]byte) {
+ // Create an empty trie
+ logger := hclog.NewNullLogger()
+ triedb := NewDatabase(rawdb.NewMemoryDatabase(), logger)
+ trie, _ := NewStateTrie(TrieID(types.Hash{}), triedb, logger)
+
+ // Fill it with some arbitrary data
+ content := make(map[string][]byte)
+ for i := byte(0); i < 255; i++ {
+ // Map the same data under multiple keys
+ key, val := types.LeftPadBytes([]byte{1, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ key, val = types.LeftPadBytes([]byte{2, i}, 32), []byte{i}
+ content[string(key)] = val
+ trie.Update(key, val)
+
+ // Add some other data to inflate the trie
+ for j := byte(3); j < 13; j++ {
+ key, val = types.LeftPadBytes([]byte{j, i}, 32), []byte{j, i}
+ content[string(key)] = val
+ trie.Update(key, val)
+ }
+ }
+ root, nodes, err := trie.Commit(false)
+ if err != nil {
+ panic(fmt.Errorf("failed to commit trie %v", err))
+ }
+ if err := triedb.Update(NewWithNodeSet(nodes)); err != nil {
+ panic(fmt.Errorf("failed to commit db %v", err))
+ }
+ // Re-create the trie based on the new state
+ trie, _ = NewStateTrie(TrieID(root), triedb, logger)
+ return triedb, trie, content
+}
diff --git a/trie/utils.go b/trie/utils.go
new file mode 100644
index 0000000000..13532bacd9
--- /dev/null
+++ b/trie/utils.go
@@ -0,0 +1,223 @@
+// Copyright 2022 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+// tracer tracks the changes of trie nodes. During the trie operations,
+// some nodes can be deleted from the trie, while these deleted nodes
+// won't be captured by trie.Hasher or trie.Committer. Thus, these deleted
+// nodes won't be removed from the disk at all. Tracer is an auxiliary tool
+// used to track all insert and delete operations of trie and capture all
+// deleted nodes eventually.
+//
+// The changed nodes can be mainly divided into two categories: the leaf
+// node and intermediate node. The former is inserted/deleted by callers
+// while the latter is inserted/deleted in order to follow the rule of trie.
+// This tool can track all of them no matter the node is embedded in its
+// parent or not, but valueNode is never tracked.
+//
+// Besides, it's also used for recording the original value of the nodes
+// when they are resolved from the disk. The pre-value of the nodes will
+// be used to construct reverse-diffs in the future.
+//
+// Note tracer is not thread-safe, callers should be responsible for handling
+// the concurrency issues by themselves.
+type tracer struct {
+ insert map[string]struct{}
+ delete map[string]struct{}
+ origin map[string][]byte
+}
+
+// newTracer initializes the tracer for capturing trie changes.
+func newTracer() *tracer {
+ return &tracer{
+ insert: make(map[string]struct{}),
+ delete: make(map[string]struct{}),
+ origin: make(map[string][]byte),
+ }
+}
+
+// onRead tracks the newly loaded trie node and caches the rlp-encoded blob internally.
+// Don't change the value outside of function since it's not deep-copied.
+func (t *tracer) onRead(path []byte, val []byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+
+ t.origin[string(path)] = val
+}
+
+// onInsert tracks the newly inserted trie node. If it's already in the deletion set
+// (resurrected node), then just wipe it from the deletion set as the "untouched".
+func (t *tracer) onInsert(path []byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+
+ if _, present := t.delete[string(path)]; present {
+ delete(t.delete, string(path))
+
+ return
+ }
+
+ t.insert[string(path)] = struct{}{}
+}
+
+// onDelete tracks the newly deleted trie node. If it's already
+// in the addition set, then just wipe it from the addition set
+// as it's untouched.
+func (t *tracer) onDelete(path []byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+
+ if _, present := t.insert[string(path)]; present {
+ delete(t.insert, string(path))
+
+ return
+ }
+
+ t.delete[string(path)] = struct{}{}
+}
+
+// insertList returns the tracked inserted trie nodes in list format.
+func (t *tracer) insertList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+
+ ret := make([][]byte, 0, len(t.insert))
+ for path := range t.insert {
+ ret = append(ret, []byte(path))
+ }
+
+ return ret
+}
+
+// deleteList returns the tracked deleted trie nodes in list format.
+func (t *tracer) deleteList() [][]byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+
+ var ret = make([][]byte, 0, len(t.delete))
+
+ for path := range t.delete {
+ ret = append(ret, []byte(path))
+ }
+
+ return ret
+}
+
+// prevList returns the tracked node blobs in list format.
+func (t *tracer) prevList() ([][]byte, [][]byte) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil, nil
+ }
+
+ var (
+ paths = make([][]byte, 0, len(t.origin))
+ blobs = make([][]byte, 0, len(t.origin))
+ )
+
+ for path, blob := range t.origin {
+ paths = append(paths, []byte(path))
+ blobs = append(blobs, blob)
+ }
+
+ return paths, blobs
+}
+
+// getPrev returns the cached original value of the specified node.
+func (t *tracer) getPrev(path []byte) []byte {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+
+ return t.origin[string(path)]
+}
+
+// reset clears the content tracked by tracer.
+func (t *tracer) reset() {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+
+ t.insert = make(map[string]struct{})
+ t.delete = make(map[string]struct{})
+ t.origin = make(map[string][]byte)
+}
+
+// copy returns a deep copied tracer instance.
+func (t *tracer) copy() *tracer {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return nil
+ }
+
+ var (
+ insert = make(map[string]struct{})
+ needDelete = make(map[string]struct{})
+ origin = make(map[string][]byte)
+ )
+
+ for key := range t.insert {
+ insert[key] = struct{}{}
+ }
+
+ for key := range t.delete {
+ needDelete[key] = struct{}{}
+ }
+
+ for key, val := range t.origin {
+ origin[key] = val
+ }
+
+ return &tracer{
+ insert: insert,
+ delete: needDelete,
+ origin: origin,
+ }
+}
+
+// markDeletions puts all tracked deletions into the provided nodeset.
+func (t *tracer) markDeletions(set *NodeSet) {
+ // Tracer isn't used right now, remove this check later.
+ if t == nil {
+ return
+ }
+
+ for _, path := range t.deleteList() {
+ // There are a few possibilities for this scenario (the node is deleted
+ // but not present in database previously), for example the node was
+ // embedded in the parent and now deleted from the trie. In this case
+ // it's noop from database's perspective.
+ val := t.getPrev(path)
+ if len(val) == 0 {
+ continue
+ }
+
+ set.markDeleted(path, val)
+ }
+}
diff --git a/types/big.go b/types/big.go
new file mode 100644
index 0000000000..d13957efc0
--- /dev/null
+++ b/types/big.go
@@ -0,0 +1,8 @@
+package types
+
+import "math/big"
+
+// Common big integers often used
+var (
+ Big1 = big.NewInt(1)
+)
diff --git a/types/bytes.go b/types/bytes.go
new file mode 100644
index 0000000000..056210bf9e
--- /dev/null
+++ b/types/bytes.go
@@ -0,0 +1,77 @@
+package types
+
+import (
+ "encoding/hex"
+ "strings"
+)
+
+// CopyBytes returns an exact copy of the provided bytes.
+func CopyBytes(b []byte) (copiedBytes []byte) {
+ if b == nil {
+ return nil
+ }
+
+ copiedBytes = make([]byte, len(b))
+ copy(copiedBytes, b)
+
+ return
+}
+
+func StringToBytes(str string) []byte {
+ str = strings.TrimPrefix(str, "0x")
+ if len(str)%2 == 1 {
+ str = "0" + str
+ }
+
+ b, _ := hex.DecodeString(str)
+
+ return b
+}
+
+// RightPadBytes zero-pads slice to the right up to length l.
+func RightPadBytes(slice []byte, l int) []byte {
+ if l <= len(slice) {
+ return slice
+ }
+
+ padded := make([]byte, l)
+ copy(padded, slice)
+
+ return padded
+}
+
+// LeftPadBytes zero-pads slice to the left up to length l.
+func LeftPadBytes(slice []byte, l int) []byte {
+ if l <= len(slice) {
+ return slice
+ }
+
+ padded := make([]byte, l)
+ copy(padded[l-len(slice):], slice)
+
+ return padded
+}
+
+// TrimLeftZeroes returns a subslice of s without leading zeroes
+func TrimLeftZeroes(s []byte) []byte {
+ idx := 0
+ for ; idx < len(s); idx++ {
+ if s[idx] != 0 {
+ break
+ }
+ }
+
+ return s[idx:]
+}
+
+// TrimRightZeroes returns a subslice of s without trailing zeroes
+func TrimRightZeroes(s []byte) []byte {
+ idx := len(s)
+ for ; idx > 0; idx-- {
+ if s[idx-1] != 0 {
+ break
+ }
+ }
+
+ return s[:idx]
+}
diff --git a/types/bytes_test.go b/types/bytes_test.go
new file mode 100644
index 0000000000..f664dd751c
--- /dev/null
+++ b/types/bytes_test.go
@@ -0,0 +1,86 @@
+package types
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestCopyBytes(t *testing.T) {
+ input := []byte{1, 2, 3, 4}
+
+ v := CopyBytes(input)
+ if !bytes.Equal(v, []byte{1, 2, 3, 4}) {
+ t.Fatal("not equal after copy")
+ }
+
+ v[0] = 99
+
+ if bytes.Equal(v, input) {
+ t.Fatal("result is not a copy")
+ }
+}
+
+func TestStringToBytes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ arr []byte
+ exp []byte
+ }{
+ {StringToBytes("0x00ffff00ff0000"), []byte{0x00, 0xff, 0xff, 0x00, 0xff, 0x00, 0x00}},
+ {StringToBytes("0x00000000000000"), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {StringToBytes("0xff"), []byte{0xff}},
+ {[]byte{}, []byte{}},
+ {StringToBytes("0x00ffffffffffff"), []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ }
+
+ for i, test := range tests {
+ if !bytes.Equal(test.arr, test.exp) {
+ t.Errorf("test %d, got %x exp %x", i, test.arr, test.exp)
+ }
+ }
+}
+
+func TestTrimLeftZeroes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ arr []byte
+ exp []byte
+ }{
+ {StringToBytes("0x00ffff00ff0000"), StringToBytes("0xffff00ff0000")},
+ {StringToBytes("0x00000000000000"), []byte{}},
+ {StringToBytes("0xff"), StringToBytes("0xff")},
+ {[]byte{}, []byte{}},
+ {StringToBytes("0x00ffffffffffff"), StringToBytes("0xffffffffffff")},
+ }
+
+ for i, test := range tests {
+ got := TrimLeftZeroes(test.arr)
+ if !bytes.Equal(got, test.exp) {
+ t.Errorf("test %d, got %x exp %x", i, got, test.exp)
+ }
+ }
+}
+
+func TestTrimRightZeroes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ arr []byte
+ exp []byte
+ }{
+ {StringToBytes("0x00ffff00ff0000"), StringToBytes("0x00ffff00ff")},
+ {StringToBytes("0x00000000000000"), []byte{}},
+ {StringToBytes("0xff"), StringToBytes("0xff")},
+ {[]byte{}, []byte{}},
+ {StringToBytes("0x00ffffffffffff"), StringToBytes("0x00ffffffffffff")},
+ }
+
+ for i, test := range tests {
+ got := TrimRightZeroes(test.arr)
+ if !bytes.Equal(got, test.exp) {
+ t.Errorf("test %d, got %x exp %x", i, got, test.exp)
+ }
+ }
+}
diff --git a/types/format.go b/types/format.go
new file mode 100644
index 0000000000..7c996da30e
--- /dev/null
+++ b/types/format.go
@@ -0,0 +1,84 @@
+// Copyright 2016 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types
+
+import (
+ "fmt"
+ "regexp"
+ "strings"
+ "time"
+)
+
+// PrettyDuration is a pretty printed version of a time.Duration value that cuts
+// the unnecessary precision off from the formatted textual representation.
+type PrettyDuration time.Duration
+
+var prettyDurationRe = regexp.MustCompile(`\.[0-9]{4,}`)
+
+// String implements the Stringer interface, allowing pretty printing of duration
+// values rounded to three decimals.
+func (d PrettyDuration) String() string {
+ label := time.Duration(d).String()
+ if match := prettyDurationRe.FindString(label); len(match) > 4 {
+ label = strings.Replace(label, match, match[:4], 1)
+ }
+
+ return label
+}
+
+// PrettyAge is a pretty printed version of a time.Duration value that rounds
+// the values up to a single most significant unit, days/weeks/years included.
+type PrettyAge time.Time
+
+// ageUnits is a list of units the age pretty printing uses.
+var ageUnits = []struct {
+ Size time.Duration
+ Symbol string
+}{
+ {12 * 30 * 24 * time.Hour, "y"},
+ {30 * 24 * time.Hour, "mo"},
+ {7 * 24 * time.Hour, "w"},
+ {24 * time.Hour, "d"},
+ {time.Hour, "h"},
+ {time.Minute, "m"},
+ {time.Second, "s"},
+}
+
+// String implements the Stringer interface, allowing pretty printing of duration
+// values rounded to the most significant time unit.
+func (t PrettyAge) String() string {
+ // Calculate the time difference and handle the 0 cornercase
+ diff := time.Since(time.Time(t))
+ if diff < time.Second {
+ return "0"
+ }
+ // Accumulate a precision of 3 components before returning
+ result, prec := "", 0
+
+ for _, unit := range ageUnits {
+ if diff > unit.Size {
+ result = fmt.Sprintf("%s%d%s", result, diff/unit.Size, unit.Symbol)
+ diff %= unit.Size
+
+ if prec += 1; prec >= 3 {
+ break
+ }
+ }
+ }
+
+ return result
+}
diff --git a/types/hash.go b/types/hash.go
index afda5c3028..0059f311c8 100644
--- a/types/hash.go
+++ b/types/hash.go
@@ -10,13 +10,18 @@ import (
var (
// ZeroHash is all zero hash
+ // It is used as a constant for comparison.
+ // Do not return its slice, otherwise it might be overwritten.
ZeroHash = Hash{}
// EmptyRootHash is the root when there are no transactions
EmptyRootHash = StringToHash("0x56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
- // EmptyUncleHash is the root when there are no uncles
+ // EmptyUncleHash is the root when there are no uncles, i.e. the Keccak256 hash of the RLP encoding of an empty header list ([]*Header(nil)).
EmptyUncleHash = StringToHash("0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347")
+
+ // EmptyCodeHash is the known hash of the empty EVM bytecode, i.e. the value of crypto.Keccak256(nil).
+ EmptyCodeHash = StringToHash("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470")
)
const (
diff --git a/types/rlp_unmarshal.go b/types/rlp_unmarshal.go
index f3a5953f52..3db8e84c25 100644
--- a/types/rlp_unmarshal.go
+++ b/types/rlp_unmarshal.go
@@ -14,26 +14,26 @@ type RLPUnmarshaler interface {
type unmarshalRLPFunc func(p *fastrlp.Parser, v *fastrlp.Value) error
 func UnmarshalRlp(obj unmarshalRLPFunc, input []byte) error {
-	pr := fastrlp.DefaultParserPool.Get()
+	// A zero-value parser is ready to use; a per-call parser instead of the
+	// shared pool removes the Put bookkeeping on every return path.
+	var pr fastrlp.Parser
 	v, err := pr.Parse(input)
 	if err != nil {
-		fastrlp.DefaultParserPool.Put(pr)
-
 		return err
 	}
-	if err := obj(pr, v); err != nil {
-		fastrlp.DefaultParserPool.Put(pr)
-
+	if err := obj(&pr, v); err != nil {
 		return err
 	}
-	fastrlp.DefaultParserPool.Put(pr)
-
 	return nil
 }
+// RlpUnmarshal parses the RLP-encoded input and returns the root parsed value.
+// NOTE(review): the returned *fastrlp.Value is produced by a parser local to
+// this call — confirm against fastrlp docs before retaining it long-term.
+func RlpUnmarshal(input []byte) (*fastrlp.Value, error) {
+	var p fastrlp.Parser
+
+	return p.Parse(input)
+}
+
func (b *Block) UnmarshalRLP(input []byte) error {
return UnmarshalRlp(b.UnmarshalRLPFrom, input)
}
diff --git a/types/size.go b/types/size.go
new file mode 100644
index 0000000000..45ad87d223
--- /dev/null
+++ b/types/size.go
@@ -0,0 +1,56 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types
+
+import (
+ "fmt"
+)
+
+// StorageSize is a wrapper around a float value, expressed in bytes, that
+// supports user friendly formatting.
+type StorageSize float64
+
+// String implements the fmt.Stringer interface using IEC binary unit suffixes.
+func (s StorageSize) String() string {
+	if s > 1099511627776 {
+		return fmt.Sprintf("%.2f TiB", s/1099511627776)
+	} else if s > 1073741824 {
+		return fmt.Sprintf("%.2f GiB", s/1073741824)
+	} else if s > 1048576 {
+		return fmt.Sprintf("%.2f MiB", s/1048576)
+	} else if s > 1024 {
+		return fmt.Sprintf("%.2f KiB", s/1024)
+	} else {
+		return fmt.Sprintf("%.2f B", s)
+	}
+}
+
+// TerminalString implements log.TerminalStringer, formatting the size for
+// compact console output during logging (no space between value and unit).
+func (s StorageSize) TerminalString() string {
+	if s > 1099511627776 {
+		return fmt.Sprintf("%.2fTiB", s/1099511627776)
+	} else if s > 1073741824 {
+		return fmt.Sprintf("%.2fGiB", s/1073741824)
+	} else if s > 1048576 {
+		return fmt.Sprintf("%.2fMiB", s/1048576)
+	} else if s > 1024 {
+		return fmt.Sprintf("%.2fKiB", s/1024)
+	} else {
+		return fmt.Sprintf("%.2fB", s)
+	}
+}
diff --git a/types/size_test.go b/types/size_test.go
new file mode 100644
index 0000000000..3e39e4c7df
--- /dev/null
+++ b/types/size_test.go
@@ -0,0 +1,59 @@
+// Copyright 2014 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package types
+
+import (
+ "testing"
+)
+
+// TestStorageSizeString checks String against known byte counts and their
+// expected two-decimal IEC renderings.
+func TestStorageSizeString(t *testing.T) {
+	tests := []struct {
+		size StorageSize
+		str  string
+	}{
+		{2839274474874, "2.58 TiB"},
+		{2458492810, "2.29 GiB"},
+		{2381273, "2.27 MiB"},
+		{2192, "2.14 KiB"},
+		{12, "12.00 B"},
+	}
+
+	for _, test := range tests {
+		if test.size.String() != test.str {
+			t.Errorf("%f: got %q, want %q", float64(test.size), test.size.String(), test.str)
+		}
+	}
+}
+
+// TestStorageSizeTerminalString checks TerminalString against known byte
+// counts; same values as TestStorageSizeString but without the unit space.
+func TestStorageSizeTerminalString(t *testing.T) {
+	tests := []struct {
+		size StorageSize
+		str  string
+	}{
+		{2839274474874, "2.58TiB"},
+		{2458492810, "2.29GiB"},
+		{2381273, "2.27MiB"},
+		{2192, "2.14KiB"},
+		{12, "12.00B"},
+	}
+
+	for _, test := range tests {
+		if test.size.TerminalString() != test.str {
+			t.Errorf("%f: got %q, want %q", float64(test.size), test.size.TerminalString(), test.str)
+		}
+	}
+}
diff --git a/types/util.go b/types/util.go
index aea534dbe7..0d626e7007 100644
--- a/types/util.go
+++ b/types/util.go
@@ -1,11 +1,5 @@
package types
-import (
- "strings"
-
- "github.com/dogechain-lab/dogechain/helper/hex"
-)
-
func min(i, j int) int {
if i < j {
return i
@@ -13,14 +7,3 @@ func min(i, j int) int {
return j
}
-
-func StringToBytes(str string) []byte {
- str = strings.TrimPrefix(str, "0x")
- if len(str)%2 == 1 {
- str = "0" + str
- }
-
- b, _ := hex.DecodeString(str)
-
- return b
-}