diff --git a/trie/hasher.go b/trie/hasher.go index 54f6a9de2b6a..4dc706209cf2 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -26,8 +26,8 @@ import ( ) type hasher struct { - tmp sliceBuffer - sha keccakState + tmp []sliceBuffer + sha []keccakState onleaf LeafCallback } @@ -54,8 +54,18 @@ func (b *sliceBuffer) Reset() { var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ - tmp: make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + tmp: []sliceBuffer{ + make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. + make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. + make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. + make(sliceBuffer, 0, 550), // cap is as large as a full fullNode. + }, + sha: []keccakState{ + sha3.NewLegacyKeccak256().(keccakState), + sha3.NewLegacyKeccak256().(keccakState), + sha3.NewLegacyKeccak256().(keccakState), + sha3.NewLegacyKeccak256().(keccakState), + }, } }, } @@ -73,6 +83,9 @@ func returnHasherToPool(h *hasher) { // hash collapses a node down into a hash node, also returning a copy of the // original node initialized with the computed hash to replace the original one. func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) { + return h.hashParalell(n, db, force, 0) +} +func (h *hasher) hashParalell(n node, db *Database, force bool, id int) (node, node, error) { // If we're not storing the node, just hashing, use available cached data if hash, dirty := n.cache(); hash != nil { if db == nil { @@ -88,11 +101,14 @@ func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) { } } // Trie not processed yet or needs storage, walk the children - collapsed, cached, err := h.hashChildren(n, db) + collapsed, cached, err := h.hashChildrenParalell(n, db, id) if err != nil { return hashNode{}, n, err } - hashed, err := h.store(collapsed, db, force) + if id == -1 { + id = 0 + } + hashed, err := h.store(collapsed, db, force, id) if err != nil { return hashNode{}, n, err } @@ -119,6 +135,10 @@ func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) { // size of the child is larger than a hash, returning the collapsed node as well // as a replacement for the original node with the child hashes cached in. func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { + return h.hashChildrenParalell(original, db, 0) +} + +func (h *hasher) hashChildrenParalell(original node, db *Database, id int) (node, node, error) { var err error switch n := original.(type) { @@ -129,7 +149,7 @@ func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { cached.Key = common.CopyBytes(n.Key) if _, ok := n.Val.(valueNode); !ok { - collapsed.Val, cached.Val, err = h.hash(n.Val, db, false) + collapsed.Val, cached.Val, err = h.hashParalell(n.Val, db, false, id) if err != nil { return original, original, err } @@ -139,12 +159,74 @@ func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { case *fullNode: // Hash the full node's children, caching the newly hashed subtrees collapsed, cached := n.copy(), n.copy() + if id == -1 { // Top level, thread out + var wg sync.WaitGroup + wg.Add(3) + var e1, e2, e3, e4 error + go func() { + for i := 0; i < 4; i++ { + if n.Children[i] != nil { + collapsed.Children[i], cached.Children[i], e1 = h.hashParalell(n.Children[i], db, false, 0) + if err != nil { + return + } + } + } + wg.Done() + }() + go func() { + for i := 4; i < 8; i++ { + if n.Children[i] != nil { + collapsed.Children[i], cached.Children[i], e2 = h.hashParalell(n.Children[i], db, false, 1) + if err != nil { + return + } + } + } + wg.Done() + }() + go func() { + for i := 8; i < 12; i++ { + if n.Children[i] != nil { + collapsed.Children[i], cached.Children[i], e3 = h.hashParalell(n.Children[i], db, false, 2) + if err != nil { + return + } + } + } + wg.Done() + }() + for i := 12; i < 16; i++ { + if n.Children[i] != nil { + collapsed.Children[i], cached.Children[i], e4 = h.hashParalell(n.Children[i], db, false, 3) + if err != nil { + break + } + } + } + wg.Wait() + if e1 != nil { + return original, original, e1 + } + if e2 != nil { + return original, original, e2 + } + + if e3 != nil { + return original, original, e3 + } + + if e4 != nil { + return original, original, e4 + } - for i := 0; i < 16; i++ { - if n.Children[i] != nil { - collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false) - if err != nil { - return original, original, err + } else { + for i := 0; i < 16; i++ { + if n.Children[i] != nil { + collapsed.Children[i], cached.Children[i], err = h.hashParalell(n.Children[i], db, false, id) + if err != nil { + return original, original, err + } } } } @@ -160,31 +242,31 @@ func (h *hasher) hashChildren(original node, db *Database) (node, node, error) { // store hashes the node n and if we have a storage layer specified, it writes // the key/value pair to it and tracks any node->child references as well as any // node->external trie references. -func (h *hasher) store(n node, db *Database, force bool) (node, error) { +func (h *hasher) store(n node, db *Database, force bool, id int) (node, error) { // Don't store hashes or empty nodes. if _, isHash := n.(hashNode); n == nil || isHash { return n, nil } - // Generate the RLP encoding of the node - h.tmp.Reset() - if err := rlp.Encode(&h.tmp, n); err != nil { - panic("encode error: " + err.Error()) - } - if len(h.tmp) < 32 && !force { - return n, nil // Nodes smaller than 32 bytes are stored inside their parent - } - // Larger nodes are replaced by their hash and stored in the database. + // We might already have the hash hash, _ := n.cache() if hash == nil { - hash = h.makeHashNode(h.tmp) + // Generate the RLP encoding of the node + h.tmp[id].Reset() + if err := rlp.Encode(&h.tmp[id], n); err != nil { + panic("encode error: " + err.Error()) + } + if len(h.tmp[id]) < 32 && !force { + return n, nil // Nodes smaller than 32 bytes are stored inside their parent + } + // Larger nodes are replaced by their hash and stored in the database. + hash = h.makeHashNode(id) } - if db != nil { // We are pooling the trie nodes into an intermediate memory cache hash := common.BytesToHash(hash) db.lock.Lock() - db.insert(hash, h.tmp, n) + db.insert(hash, h.tmp[id], n) db.lock.Unlock() // Track external references from account->storage trie @@ -206,10 +288,10 @@ func (h *hasher) store(n node, db *Database, force bool) (node, error) { return hash, nil } -func (h *hasher) makeHashNode(data []byte) hashNode { - n := make(hashNode, h.sha.Size()) - h.sha.Reset() - h.sha.Write(data) - h.sha.Read(n) +func (h *hasher) makeHashNode(id int) hashNode { + n := make(hashNode, h.sha[id].Size()) + h.sha[id].Reset() + h.sha[id].Write(h.tmp[id]) + h.sha[id].Read(n) return n } diff --git a/trie/iterator.go b/trie/iterator.go index 8e84dee3b617..02fc645d0d3a 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -190,7 +190,7 @@ func (it *nodeIterator) LeafProof() [][]byte { for i, item := range it.stack[:len(it.stack)-1] { // Gather nodes that end up as hash nodes (or the root) node, _, _ := hasher.hashChildren(item.node, nil) - hashed, _ := hasher.store(node, nil, false) + hashed, _ := hasher.store(node, nil, false,0) if _, ok := hashed.(hashNode); ok || i == 0 { enc, _ := rlp.EncodeToBytes(node) proofs = append(proofs, enc) diff --git a/trie/proof.go b/trie/proof.go index 9985e730dd37..3f5845b41a4d 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -71,7 +71,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. n, _, _ = hasher.hashChildren(n, nil) - hn, _ := hasher.store(n, nil, false) + hn, _ := hasher.store(n, nil, false, 0) if hash, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the // root node), it becomes a proof element. @@ -80,7 +80,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e } else { enc, _ := rlp.EncodeToBytes(n) if !ok { - hash = hasher.makeHashNode(enc) + hash = hasher.makeHashNode(0) } proofDb.Put(hash, enc) } diff --git a/trie/secure_trie.go b/trie/secure_trie.go index fbc591ed108a..41a7d3b1e801 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -177,9 +177,9 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator { // invalid on the next call to hashKey or secKey. func (t *SecureTrie) hashKey(key []byte) []byte { h := newHasher(nil) - h.sha.Reset() - h.sha.Write(key) - buf := h.sha.Sum(t.hashKeyBuf[:0]) + h.sha[0].Reset() + h.sha[0].Write(key) + buf := h.sha[0].Sum(t.hashKeyBuf[:0]) returnHasherToPool(h) return buf } diff --git a/trie/trie.go b/trie/trie.go index 920e331fd62f..9d0c489d3b42 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -429,5 +429,5 @@ func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) { } h := newHasher(onleaf) defer returnHasherToPool(h) - return h.hash(t.root, db, true) + return h.hashParalell(t.root, db, true, -1) } diff --git a/trie/trie_test.go b/trie/trie_test.go index e53ac568e9c3..70f293643b9e 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -512,6 +512,72 @@ func BenchmarkHash(b *testing.B) { trie.Hash() } +// Benchmarks the trie Commit following a Hash. Since the trie caches the result of any operation, +// we cannot use b.N as the number of hashing rouns, since all rounds apart from +// the first one will be NOOP. As such, we'll use b.N as the number of account to +// insert into the trie before measuring the hashing. +func BenchmarkCommitAfterHash(b *testing.B) { + // Make the random benchmark deterministic + random := rand.New(rand.NewSource(0)) + + // Create a realistic account trie to hash + addresses := make([][20]byte, b.N) + for i := 0; i < len(addresses); i++ { + for j := 0; j < len(addresses[i]); j++ { + addresses[i][j] = byte(random.Intn(256)) + } + } + accounts := make([][]byte, len(addresses)) + for i := 0; i < len(accounts); i++ { + var ( + nonce = uint64(random.Int63()) + balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + root = emptyRoot + code = crypto.Keccak256(nil) + ) + accounts[i], _ = rlp.EncodeToBytes([]interface{}{nonce, balance, root, code}) + } + // Insert the accounts into the trie and hash it + trie := newEmpty() + for i := 0; i < len(addresses); i++ { + trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + } + trie.Hash() + b.ResetTimer() + b.ReportAllocs() + trie.Commit(nil) +} + +func TestCommitAfterHash(t *testing.T) { + // Make the random benchmark deterministic + random := rand.New(rand.NewSource(0)) + + // Create a realistic account trie to hash + addresses := make([][20]byte, 10000) + for i := 0; i < len(addresses); i++ { + for j := 0; j < len(addresses[i]); j++ { + addresses[i][j] = byte(random.Intn(256)) + } + } + accounts := make([][]byte, len(addresses)) + for i := 0; i < len(accounts); i++ { + var ( + nonce = uint64(random.Int63()) + balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + root = emptyRoot + code = crypto.Keccak256(nil) + ) + accounts[i], _ = rlp.EncodeToBytes([]interface{}{nonce, balance, root, code}) + } + // Insert the accounts into the trie and hash it + trie := newEmpty() + for i := 0; i < len(addresses); i++ { + trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i]) + } + trie.Hash() + trie.Commit(nil) +} + func tempDB() (string, *Database) { dir, err := ioutil.TempDir("", "trie-bench") if err != nil {