Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions trie/stacktrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
package trie

import (
"bufio"
"bytes"
"encoding/gob"
"errors"
"fmt"
"io"
"sync"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -66,6 +70,96 @@ func NewStackTrie(db ethdb.KeyValueWriter) *StackTrie {
}
}

// NewFromBinary initialises a serialized stacktrie with the given db.
func NewFromBinary(data []byte, db ethdb.KeyValueWriter) (*StackTrie, error) {
var st StackTrie
if err := st.UnmarshalBinary(data); err != nil {
return nil, err
}
// If a database is used, we need to recursively add it to every child
if db != nil {
st.setDb(db)
}
return &st, nil
}

// MarshalBinary implements encoding.BinaryMarshaler
func (st *StackTrie) MarshalBinary() (data []byte, err error) {
var (
b bytes.Buffer
w = bufio.NewWriter(&b)
)
if err := gob.NewEncoder(w).Encode(struct {
Nodetype uint8
Val []byte
Key []byte
KeyOffset uint8
}{
st.nodeType,
st.val,
st.key,
uint8(st.keyOffset),
}); err != nil {
return nil, err
}
for _, child := range st.children {
if child == nil {
w.WriteByte(0)
continue
}
w.WriteByte(1)
if childData, err := child.MarshalBinary(); err != nil {
return nil, err
} else {
w.Write(childData)
}
}
w.Flush()
return b.Bytes(), nil
}

// UnmarshalBinary implements encoding.BinaryUnmarshaler
func (st *StackTrie) UnmarshalBinary(data []byte) error {
r := bytes.NewReader(data)
return st.unmarshalBinary(r)
}

func (st *StackTrie) unmarshalBinary(r io.Reader) error {
var dec struct {
Nodetype uint8
Val []byte
Key []byte
KeyOffset uint8
}
gob.NewDecoder(r).Decode(&dec)
st.nodeType = dec.Nodetype
st.val = dec.Val
st.key = dec.Key
st.keyOffset = int(dec.KeyOffset)

var hasChild = make([]byte, 1)
for i := range st.children {
if _, err := r.Read(hasChild); err != nil {
return err
} else if hasChild[0] == 0 {
continue
}
var child StackTrie
child.unmarshalBinary(r)
st.children[i] = &child
}
return nil
}

func (st *StackTrie) setDb(db ethdb.KeyValueWriter) {
st.db = db
for _, child := range st.children {
if child != nil {
child.setDb(db)
}
}
}

func newLeaf(ko int, key, val []byte, db ethdb.KeyValueWriter) *StackTrie {
st := stackTrieFromPool(db)
st.nodeType = leafNode
Expand Down
47 changes: 46 additions & 1 deletion trie/stacktrie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ func TestStacktrieNotModifyValues(t *testing.T) {
return big.NewInt(int64(i)).Bytes()
}
}

for i := 0; i < 1000; i++ {
key := common.BigToHash(keyB)
value := getValue(i)
Expand All @@ -168,5 +167,51 @@ func TestStacktrieNotModifyValues(t *testing.T) {
if !bytes.Equal(have, want) {
t.Fatalf("item %d, have %#x want %#x", i, have, want)
}

}
}

// TestStacktrieSerialization tests that the stacktrie works well if we
// serialize/unserialize it a lot
func TestStacktrieSerialization(t *testing.T) {
var (
st = NewStackTrie(nil)
nt, _ = New(common.Hash{}, NewDatabase(memorydb.New()))
keyB = big.NewInt(1)
keyDelta = big.NewInt(1)
vals [][]byte
keys [][]byte
)
getValue := func(i int) []byte {
if i%2 == 0 { // large
return crypto.Keccak256(big.NewInt(int64(i)).Bytes())
} else { //small
return big.NewInt(int64(i)).Bytes()
}
}
for i := 0; i < 10; i++ {
vals = append(vals, getValue(i))
keys = append(keys, common.BigToHash(keyB).Bytes())
keyB = keyB.Add(keyB, keyDelta)
keyDelta.Add(keyDelta, common.Big1)
}
for i, k := range keys {
nt.TryUpdate(k, common.CopyBytes(vals[i]))
}

for i, k := range keys {
blob, err := st.MarshalBinary()
if err != nil {
t.Fatal(err)
}
newSt, err := NewFromBinary(blob, nil)
if err != nil {
t.Fatal(err)
}
st = newSt
st.TryUpdate(k, common.CopyBytes(vals[i]))
}
if have, want := st.Hash(), nt.Hash(); have != want {
t.Fatalf("have %#x want %#x", have, want)
}
}