Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ type SparseMerkleTree struct {
root []byte
}

// LeafNode represents the contents of a tree leaf.
// A LeafNode with ValueHash == nil represents an absent record.
type LeafNode struct {
Path []byte
ValueHash []byte
}

// NewSparseMerkleTree creates a new Sparse Merkle tree on an empty MapStore.
func NewSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, options ...Option) *SparseMerkleTree {
smt := SparseMerkleTree{
Expand Down Expand Up @@ -91,6 +98,55 @@ func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) {
return value, nil
}

// GetLeaf gets an entire leaf node from the tree.
// If the leaf is not found, a LeafNode is returned with ValueHash == nil
func (smt *SparseMerkleTree) GetLeaf(key []byte) (*LeafNode, error) {
path := smt.th.path(key)
if bytes.Equal(smt.root, smt.th.placeholder()) {
// The tree is empty, return the default value.
return &LeafNode{Path: path, ValueHash: nil}, nil
}

currentHash := smt.root
for i := 0; i < smt.depth(); i++ {
Copy link
Copy Markdown
Collaborator

@i-norden i-norden Jul 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is outside the scope of this PR butI wonder if it would be worthwhile to create a secondary index, internal to the SMT, that maps a key/path directly to the associated leaf node so that we don't have to iterate down a branch from the root- performing multiple db lookups- to get to the leaf (or discover it is nonexistent). Or rather than internal to the SMT this could be done similar to the state snapshot in Ethereum: ethereum/go-ethereum#20152, where it is handled at levels above the tree.

currentData, err := smt.nodes.Get(currentHash)
if err != nil {
return nil, err
} else if smt.th.isLeaf(currentData) {
// We've reached the end. Is this the actual leaf?
p, valueHash := smt.th.parseLeaf(currentData)
if !bytes.Equal(path, p) {
// Nope. Therefore the key is actually empty.
return &LeafNode{Path: p, ValueHash: nil}, nil
}
// Otherwise, yes. Return the value.
return &LeafNode{Path: p, ValueHash: valueHash}, nil
}

leftNode, rightNode := smt.th.parseNode(currentData)
if getBitAtFromMSB(path, i) == right {
currentHash = rightNode
} else {
currentHash = leftNode
}

if bytes.Equal(currentHash, smt.th.placeholder()) {
// We've hit a placeholder value; this is the end.
return nil, nil
}
}

// The following lines of code should only be reached if the path is 256
// nodes high, which should be very unlikely if the underlying hash function
// is collision-resistant.
currentData, err := smt.nodes.Get(currentHash)
if err != nil {
return nil, err
}
_, valueHash := smt.th.parseLeaf(currentData)
return &LeafNode{Path: path, ValueHash: valueHash}, nil
}

// Has returns true if the value at the given key is non-default, false
// otherwise.
func (smt *SparseMerkleTree) Has(key []byte) (bool, error) {
Expand Down
105 changes: 92 additions & 13 deletions smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,22 @@ func TestSparseMerkleTreeKnown(t *testing.T) {
}
}

// Make two neighboring keys.
func neighboringKeys(size int) ([]byte, []byte) {
// The dummy hash function expects keys to be prefixed with four bytes of 0,
// which will cause it to return the preimage itself as the digest, without
// the first four bytes.
key1 := make([]byte, size+4)
rand.Read(key1)
key1[0], key1[1], key1[2], key1[3] = byte(0), byte(0), byte(0), byte(0)
key1[size+4-1] = byte(0)
key2 := make([]byte, size+4)
copy(key2, key1)
// We make key2's least significant bit different than key1's
key2[size+4-1] = byte(1)
return key1, key2
}

// Test tree operations when two leafs are immediate neighbors.
func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
h := newDummyHasher(sha256.New())
Expand All @@ -228,19 +244,7 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) {
var value []byte
var err error

// Make two neighboring keys.
//
// The dummy hash function expects keys to prefixed with four bytes of 0,
// which will cause it to return the preimage itself as the digest, without
// the first four bytes.
key1 := make([]byte, h.Size()+4)
rand.Read(key1)
key1[0], key1[1], key1[2], key1[3] = byte(0), byte(0), byte(0), byte(0)
key1[h.Size()+4-1] = byte(0)
key2 := make([]byte, h.Size()+4)
copy(key2, key1)
// We make key2's least significant bit different than key1's
key2[h.Size()+4-1] = byte(1)
key1, key2 := neighboringKeys(h.Size())

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
Expand Down Expand Up @@ -614,3 +618,78 @@ func TestOrphanRemoval(t *testing.T) {
}
})
}

func TestGetLeaf(t *testing.T) {
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
var leaf *LeafNode
var err error

t.Run("basic", func(t *testing.T) {
leaf, err = smt.GetLeaf([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting empty key: %v", err)
}
if leaf.ValueHash != nil {
t.Error("did not get nil ValueHash when getting empty key")
}
if leaf.Path == nil {
t.Error("got nil Path when getting empty key")
}

_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
leaf, err = smt.GetLeaf([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal(leaf.Path, smt.th.path([]byte("testKey"))) {
t.Error("did not get correct path when getting non-empty key")
}
if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue"))) {
t.Error("did not get correct value hash when getting non-empty key")
}
})

h := newDummyHasher(sha256.New())
smn, smv = NewSimpleMap(), NewSimpleMap()
smt = NewSparseMerkleTree(smn, smv, h)

// Max height case
t.Run("max height case", func(t *testing.T) {
key1, key2 := neighboringKeys(h.Size())

_, err = smt.Update(key1, []byte("testValue1"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
_, err = smt.Update(key2, []byte("testValue2"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}

leaf, err = smt.GetLeaf([]byte(key1))
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal(leaf.Path, smt.th.path([]byte(key1))) {
t.Error("did not get correct path when getting non-empty key")
}
if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue1"))) {
t.Error("did not get correct value hash when getting non-empty key")
}

leaf, err = smt.GetLeaf([]byte(key2))
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal(leaf.Path, smt.th.path([]byte(key2))) {
t.Error("did not get correct path when getting non-empty key")
}
if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue2"))) {
t.Error("did not get correct value hash when getting non-empty key")
}
})
}