diff --git a/smt.go b/smt.go index 52dba01..b2fe13b 100644 --- a/smt.go +++ b/smt.go @@ -22,6 +22,13 @@ type SparseMerkleTree struct { root []byte } +// LeafNode represents the contents of a tree leaf. +// A LeafNode with ValueHash == nil represents an absent record. +type LeafNode struct { + Path []byte + ValueHash []byte +} + // NewSparseMerkleTree creates a new Sparse Merkle tree on an empty MapStore. func NewSparseMerkleTree(nodes, values MapStore, hasher hash.Hash, options ...Option) *SparseMerkleTree { smt := SparseMerkleTree{ @@ -91,6 +98,55 @@ func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) { return value, nil } +// GetLeaf gets an entire leaf node from the tree. +// If the leaf is not found, a LeafNode is returned with ValueHash == nil +func (smt *SparseMerkleTree) GetLeaf(key []byte) (*LeafNode, error) { + path := smt.th.path(key) + if bytes.Equal(smt.root, smt.th.placeholder()) { + // The tree is empty, return the default value. + return &LeafNode{Path: path, ValueHash: nil}, nil + } + + currentHash := smt.root + for i := 0; i < smt.depth(); i++ { + currentData, err := smt.nodes.Get(currentHash) + if err != nil { + return nil, err + } else if smt.th.isLeaf(currentData) { + // We've reached the end. Is this the actual leaf? + p, valueHash := smt.th.parseLeaf(currentData) + if !bytes.Equal(path, p) { + // Nope. Therefore the key is actually empty. + return &LeafNode{Path: p, ValueHash: nil}, nil + } + // Otherwise, yes. Return the value. + return &LeafNode{Path: p, ValueHash: valueHash}, nil + } + + leftNode, rightNode := smt.th.parseNode(currentData) + if getBitAtFromMSB(path, i) == right { + currentHash = rightNode + } else { + currentHash = leftNode + } + + if bytes.Equal(currentHash, smt.th.placeholder()) { + // We've hit a placeholder value; this is the end. + return nil, nil + } + } + + // The following lines of code should only be reached if the path is 256 + // nodes high, which should be very unlikely if the underlying hash function + // is collision-resistant. + currentData, err := smt.nodes.Get(currentHash) + if err != nil { + return nil, err + } + _, valueHash := smt.th.parseLeaf(currentData) + return &LeafNode{Path: path, ValueHash: valueHash}, nil +} + // Has returns true if the value at the given key is non-default, false // otherwise. func (smt *SparseMerkleTree) Has(key []byte) (bool, error) { diff --git a/smt_test.go b/smt_test.go index 43f4b39..2e193f1 100644 --- a/smt_test.go +++ b/smt_test.go @@ -220,6 +220,22 @@ func TestSparseMerkleTreeKnown(t *testing.T) { } } +// Make two neighboring keys. +func neighboringKeys(size int) ([]byte, []byte) { + // The dummy hash function expects keys to be prefixed with four bytes of 0, + // which will cause it to return the preimage itself as the digest, without + // the first four bytes. + key1 := make([]byte, size+4) + rand.Read(key1) + key1[0], key1[1], key1[2], key1[3] = byte(0), byte(0), byte(0), byte(0) + key1[size+4-1] = byte(0) + key2 := make([]byte, size+4) + copy(key2, key1) + // We make key2's least significant bit different than key1's + key2[size+4-1] = byte(1) + return key1, key2 +} + // Test tree operations when two leafs are immediate neighbors. func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { h := newDummyHasher(sha256.New()) @@ -228,19 +244,7 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { var value []byte var err error - // Make two neighboring keys. - // - // The dummy hash function expects keys to prefixed with four bytes of 0, - // which will cause it to return the preimage itself as the digest, without - // the first four bytes. - key1 := make([]byte, h.Size()+4) - rand.Read(key1) - key1[0], key1[1], key1[2], key1[3] = byte(0), byte(0), byte(0), byte(0) - key1[h.Size()+4-1] = byte(0) - key2 := make([]byte, h.Size()+4) - copy(key2, key1) - // We make key2's least significant bit different than key1's - key2[h.Size()+4-1] = byte(1) + key1, key2 := neighboringKeys(h.Size()) _, err = smt.Update(key1, []byte("testValue1")) if err != nil { @@ -614,3 +618,78 @@ func TestOrphanRemoval(t *testing.T) { } }) } + +func TestGetLeaf(t *testing.T) { + smn, smv := NewSimpleMap(), NewSimpleMap() + smt := NewSparseMerkleTree(smn, smv, sha256.New()) + var leaf *LeafNode + var err error + + t.Run("basic", func(t *testing.T) { + leaf, err = smt.GetLeaf([]byte("testKey")) + if err != nil { + t.Errorf("returned error when getting empty key: %v", err) + } + if leaf.ValueHash != nil { + t.Error("did not get nil ValueHash when getting empty key") + } + if leaf.Path == nil { + t.Error("got nil Path when getting empty key") + } + + _, err = smt.Update([]byte("testKey"), []byte("testValue")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + leaf, err = smt.GetLeaf([]byte("testKey")) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal(leaf.Path, smt.th.path([]byte("testKey"))) { + t.Error("did not get correct path when getting non-empty key") + } + if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue"))) { + t.Error("did not get correct value hash when getting non-empty key") + } + }) + + h := newDummyHasher(sha256.New()) + smn, smv = NewSimpleMap(), NewSimpleMap() + smt = NewSparseMerkleTree(smn, smv, h) + + // Max height case + t.Run("max height case", func(t *testing.T) { + key1, key2 := neighboringKeys(h.Size()) + + _, err = smt.Update(key1, []byte("testValue1")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + _, err = smt.Update(key2, []byte("testValue2")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + + leaf, err = smt.GetLeaf([]byte(key1)) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal(leaf.Path, smt.th.path([]byte(key1))) { + t.Error("did not get correct path when getting non-empty key") + } + if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue1"))) { + t.Error("did not get correct value hash when getting non-empty key") + } + + leaf, err = smt.GetLeaf([]byte(key2)) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal(leaf.Path, smt.th.path([]byte(key2))) { + t.Error("did not get correct path when getting non-empty key") + } + if !bytes.Equal(leaf.ValueHash, smt.th.digest([]byte("testValue2"))) { + t.Error("did not get correct value hash when getting non-empty key") + } + }) +}