diff --git a/.gitignore b/.gitignore index eb4652d59e..20fea9b2fa 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,4 @@ vendor/ # Release builds /taproot-assets-* +.aider* diff --git a/mssmt/compacted_tree.go b/mssmt/compacted_tree.go index abd654cc84..5c1c46234f 100644 --- a/mssmt/compacted_tree.go +++ b/mssmt/compacted_tree.go @@ -1,10 +1,20 @@ package mssmt import ( + "bytes" "context" "fmt" + "sort" ) +// BatchedInsertionEntry represents an entry used for batched +// insertions into the MS-SMT. It consists of a key and the +// associated leaf node to insert. +type BatchedInsertionEntry struct { + Key [32]byte + Leaf *LeafNode +} + // CompactedTree represents a compacted Merkle-Sum Sparse Merkle Tree (MS-SMT). // The tree has the same properties as a normal MS-SMT tree and is able to // create the same proofs and same root as the FullTree implemented in this @@ -308,6 +318,336 @@ func (t *CompactedTree) Insert(ctx context.Context, key [hashSize]byte, return t, nil } +// processCompactedLeaf handles the insertion of a batch of entries into a slot +// that is currently occupied by a compacted leaf. A compacted leaf represents a +// compressed subtree where all branches between a specific height and the actual +// leaf are assumed to be default. Depending on the batched insertion entries, +// this function determines whether to update (i.e. replace) the existing leaf or +// to merge it with a conflicting new entry. +// +// The logic is as follows: +// +// 1. When exactly one entry is provided: +// - If the entry's key matches the compacted leaf’s key, the function treats it +// as a replacement. It deletes the existing compacted leaf from the store and +// inserts a new compacted leaf built from the provided leaf data. +// - If the entry’s key differs from the compacted leaf’s key, a conflict is +// detected and the function calls the merge helper to combine the new leaf with +// the existing leaf into a merged branch. +// +// 2. When multiple entries are provided: +// - First, it checks whether all entries share the same key as the compacted leaf. +// If they do, the function performs a replacement using the data from the last entry +// in the batch. +// - Otherwise, it finds the first entry with a key that differs from the compacted leaf +// and then invokes the merge helper to merge that conflicting leaf with the current one. +// +// In every case, the function returns the updated node (either a new compacted leaf or a +// merged branch) and any error encountered during the processing. +func (t *CompactedTree) processCompactedLeaf(tx TreeStoreUpdateTx, height int, + entries []BatchedInsertionEntry, cl *CompactedLeafNode) (Node, error) { + + // processCompactedLeaf handles the case when the current child node is + // a compacted leaf. Depending on the batch of new entries, it will either + // replace the leaf or merge it with a conflicting entry. + + // Case 1: Only one new entry. + if len(entries) == 1 { + entry := entries[0] + if entry.Key == cl.Key() { + // Replacement: key matches, so update the compacted leaf with the + // new leaf data. + newLeaf := NewCompactedLeafNode(height+1, &entry.Key, entry.Leaf) + if err := tx.DeleteCompactedLeaf(cl.NodeHash()); err != nil { + return nil, err + } + if err := tx.InsertCompactedLeaf(newLeaf); err != nil { + return nil, err + } + return newLeaf, nil + } + // Conflict: key differs – merge the new entry with the existing leaf. + return t.merge(tx, height+1, entry.Key, entry.Leaf, cl.Key(), cl.LeafNode) + } + + // Case 2: Multiple entries. + // First, check whether every entry has the same key as the compacted leaf. + allMatch := true + for _, entry := range entries { + if entry.Key != cl.Key() { + allMatch = false + break + } + } + if allMatch { + // All entries match; replace with the last entry's data. + lastEntry := entries[len(entries)-1] + newLeaf := NewCompactedLeafNode(height+1, &lastEntry.Key, lastEntry.Leaf) + if err := tx.DeleteCompactedLeaf(cl.NodeHash()); err != nil { + return nil, err + } + if err := tx.InsertCompactedLeaf(newLeaf); err != nil { + return nil, err + } + return newLeaf, nil + } + + // Otherwise, find the first entry that differs and perform a merge. + var mergeEntry *BatchedInsertionEntry + for _, entry := range entries { + if entry.Key != cl.Key() { + mergeEntry = &entry + break + } + } + if mergeEntry == nil { + return nil, fmt.Errorf("unexpected nil merge entry") + } + return t.merge(tx, height+1, mergeEntry.Key, mergeEntry.Leaf, + cl.Key(), cl.LeafNode) +} + +// processSubtree processes a batch of insertion entries for a specific +// subtree at the given height. +// +// Depending on the current state of the child node at that subtree slot, +// this method determines how to incorporate the batched entries: +// +// 1. If the child is not the default empty node (i.e. it already contains +// non-default data): +// - If the child is a compacted leaf, processSubtree delegates to the +// processCompactedLeaf helper. This handles the case where the slot +// already has a compressed leaf, performing either a replacement (if +// the batched entry’s key matches) or merging conflicting entries. +// - Otherwise, the child is assumed to be a branch node, so the batched +// entries are recursively inserted into that branch via a call to +// batchedInsert at the next tree level. +// +// 2. If the child is the default empty node (i.e. no prior insertion has +// occurred in this slot): +// - For a single entry, a new compacted leaf is created at the next +// level using that entry’s key and leaf, and inserted directly. +// - For multiple entries, an empty branch node (from the precomputed +// EmptyTree for the next level) is used as the base to recursively +// process the entries using batchedInsert. +// +// In all cases, processSubtree returns the updated node that replaces the +// current child, along with any error encountered during processing. +// +// This helper reduces nesting by separating the logic for non-empty versus +// empty subtrees and for compacted leaf handling versus full branch recursion. +func (t *CompactedTree) processSubtree(tx TreeStoreUpdateTx, height int, + entries []BatchedInsertionEntry, child Node) (Node, error) { + + // If the child is not the default empty node, then we need to process + // it accordingly. + if child != EmptyTree[height+1] { + // If the child is a compacted leaf then delegate to our helper. + if cl, ok := child.(*CompactedLeafNode); ok { + return t.processCompactedLeaf(tx, height, entries, cl) + } + + // Otherwise, child is assumed to be a branch node: + baseChild := child.(*BranchNode) + return t.batchedInsert(tx, entries, height+1, baseChild) + } + + // If the child is empty: + if len(entries) == 1 { + // With a single entry, simply create a new compacted leaf. + entry := entries[0] + newLeaf := NewCompactedLeafNode(height+1, &entry.Key, entry.Leaf) + if err := tx.InsertCompactedLeaf(newLeaf); err != nil { + return nil, err + } + return newLeaf, nil + } + + // When multiple entries share an empty child, use an empty branch node + // to recursively process the batch. + baseChild := EmptyTree[height+1].(*BranchNode) + return t.batchedInsert(tx, entries, height+1, baseChild) +} + +// partitionEntries splits the given batched insertion entries into +// two slices based on the bit at the provided height. +// Entries with bit 0 go into leftEntries and those with bit 1 into rightEntries. +func partitionEntries(entries []BatchedInsertionEntry, height int) (leftEntries, rightEntries []BatchedInsertionEntry) { + for _, entry := range entries { + if bitIndex(uint8(height), &entry.Key) == 0 { + leftEntries = append(leftEntries, entry) + } else { + rightEntries = append(rightEntries, entry) + } + } + return +} + +// batchedInsert recursively processes a batch of insertion entries +// into the subtree of the current branch node at the specified height. +// +// The function works as follows: +// +// 1. Base-Case and Empty Batch: +// - If the current level (height) has reached or exceeded the last +// bit index, no further descent is possible, so the current branch +// is simply returned. +// - If there are no entries to insert, the current branch is returned +// unchanged. +// +// 2. Partitioning Entries: +// - The batch of insertion entries is split into two groups via the helper +// function partitionEntries. Entries with a 0 bit at the current level +// go into the leftEntries slice, and those with a 1 bit go into the +// rightEntries slice. +// +// 3. Recursively Updating Subtrees: +// - The current branch’s children are retrieved using GetChildren. +// - For each side that has corresponding entries: +// - processSubtree is invoked to update that subtree. +// - The updated child node (either a newly created compacted leaf or a +// recursively updated branch) replaces the old child. +// +// 4. Reassembling the Branch: +// - A new branch is constructed from the updated left and right children. +// - The old branch (if not the default empty branch) is deleted from the +// store. +// - If the newly formed branch is not equivalent to the default empty +// node for that height, it is inserted into the store. +// +// 5. Return Value: +// - The function returns the updated branch node reflecting all batched +// insertions at that level. +// +// This helper encapsulates the recursion and merging logic for batched insertions, +// reducing nesting in the higher-level BatchedInsert method and making the overall +// insertion process easier to follow and maintain. +func (t *CompactedTree) batchedInsert(tx TreeStoreUpdateTx, entries []BatchedInsertionEntry, height int, root *BranchNode) (*BranchNode, error) { + // Base-case: If we've reached the bottom, simply return the current branch. + if height >= lastBitIndex { + return root, nil + } + + // Guard against empty batch. + if len(entries) == 0 { + return root, nil + } + + leftEntries, rightEntries := partitionEntries(entries, height) + + // Get the current children from the node. + leftChild, rightChild, err := tx.GetChildren(height, root.NodeHash()) + if err != nil { + return nil, err + } + + // Process left subtree using the helper function. + if len(leftEntries) > 0 { + newLeft, err := t.processSubtree(tx, height, leftEntries, leftChild) + if err != nil { + return nil, err + } + leftChild = newLeft + } + + // Process right subtree using the helper function. + if len(rightEntries) > 0 { + newRight, err := t.processSubtree(tx, height, rightEntries, rightChild) + if err != nil { + return nil, err + } + rightChild = newRight + } + + // Create the updated branch from the new left and right children. + var updatedBranch *BranchNode + updatedBranch = NewBranch(leftChild, rightChild) + + // Delete the old branch and insert the new one. + if root != EmptyTree[height] { + if err := tx.DeleteBranch(root.NodeHash()); err != nil { + return nil, err + } + } + if !IsEqualNode(updatedBranch, EmptyTree[height]) { + if err := tx.InsertBranch(updatedBranch); err != nil { + return nil, err + } + } + + return updatedBranch, nil +} + +// BatchedInsert inserts multiple leaf nodes into the MS-SMT as a batch +// operation. +// +// It accepts a context and a slice of BatchedInsertionEntry, where each entry +// specifies a target key and its associated leaf to be inserted. The method +// performs the batch insertion in a transactional manner on the underlying +// tree store and proceeds through the following steps: +// +// 1. Sorting: +// - The entries are first sorted in lexicographic order based on their keys. +// This consistent ordering is essential for generating valid proofs and +// ensuring history-independence. +// +// 2. Overflow Check and Transaction Setup: +// - Within an update transaction, the current root branch is retrieved. +// - For each entry, the function checks whether adding the leaf’s sum to the +// current root sum would overflow a uint64. If an overflow is detected, +// the batch insertion aborts and returns an error. +// +// 3. Recursive Processing via Helper: +// - The sorted entries are passed to the recursive helper batchedInsert along +// with the current root and a starting height of 0. +// - The helper partitions the entries based on the bit at each tree level, +// recursively updating non-empty subtrees and creating new compacted leaves +// or branch nodes as needed. +// +// 4. Root Update: +// - When recursion completes, a new root reflecting all batched updates is +// obtained. This new root is then stored in the tree store, finalizing the +// update. +// +// 5. Return: +// - On success, BatchedInsert returns the updated tree instance (implementing +// the Tree interface). In case of any errors (e.g. overflow or transactional +// failures), it returns the appropriate error. +// +// This method encapsulates the complex merging and partitioning logic required for +// efficient batched insertions into a compacted Merkle-Sum Sparse Merkle Tree. +func (t *CompactedTree) BatchedInsert(ctx context.Context, entries []BatchedInsertionEntry) (Tree, error) { + sort.Slice(entries, func(i, j int) bool { + return bytes.Compare(entries[i].Key[:], entries[j].Key[:]) < 0 + }) + + err := t.store.Update(ctx, func(tx TreeStoreUpdateTx) error { + currentRoot, err := tx.RootNode() + if err != nil { + return err + } + branchRoot := currentRoot.(*BranchNode) + + // (Optional) Loop over entries and check for sum overflow. + for _, entry := range entries { + if err := CheckSumOverflowUint64(branchRoot.NodeSum(), entry.Leaf.NodeSum()); err != nil { + return fmt.Errorf("batched insert key %v sum overflow: %w", entry.Key, err) + } + } + + // Call the new batchedInsert method. + newRoot, err := t.batchedInsert(tx, entries, 0, branchRoot) + if err != nil { + return err + } + return tx.UpdateRoot(newRoot) + }) + if err != nil { + return nil, err + } + return t, nil +} + // Delete deletes the leaf node found at the given key within the MS-SMT. func (t *CompactedTree) Delete(ctx context.Context, key [hashSize]byte) ( Tree, error) { diff --git a/mssmt/node.go b/mssmt/node.go index dbfeff6f0c..e7bad1e095 100644 --- a/mssmt/node.go +++ b/mssmt/node.go @@ -121,6 +121,9 @@ type CompactedLeafNode struct { // compactedNodeHash holds the topmost (omitted) node's node hash in the // subtree. compactedNodeHash NodeHash + + // Height is the level at which this compacted leaf was created. + Height int } // NewCompactedLeafNode creates a new compacted leaf at the passed height with @@ -144,6 +147,7 @@ func NewCompactedLeafNode(height int, key *[32]byte, compactedNodeHash: nodeHash, } + node.Height = height return node } diff --git a/mssmt/tree_test.go b/mssmt/tree_test.go index 9395715044..a3a1b5dfd5 100644 --- a/mssmt/tree_test.go +++ b/mssmt/tree_test.go @@ -1,5 +1,3 @@ -//go:build !race - package mssmt_test import ( @@ -13,13 +11,101 @@ import ( "strconv" "testing" + "crypto/sha256" "github.com/lightninglabs/taproot-assets/fn" "github.com/lightninglabs/taproot-assets/internal/test" "github.com/lightninglabs/taproot-assets/mssmt" _ "github.com/lightninglabs/taproot-assets/tapdb" + "github.com/stretchr/testify/require" ) +// TestBatchedInsert verifies that a batch of leaves is inserted correctly +// and that each inserted element can be retrieved. +func TestBatchedInsert(t *testing.T) { + ctx := context.Background() + numLeaves := 10 + var leaves []struct { + key [32]byte + leaf *mssmt.LeafNode + } + for i := 0; i < numLeaves; i++ { + key := sha256.Sum256([]byte(fmt.Sprintf("value-%d", i))) + value := []byte(fmt.Sprintf("leaf-%d", i)) + leaves = append(leaves, struct { + key [32]byte + leaf *mssmt.LeafNode + }{ + key: key, + leaf: mssmt.NewLeafNode(value, uint64(i+1)), + }) + } + + store := mssmt.NewDefaultStore() + compTree := mssmt.NewCompactedTree(store) + + // Build the batch. + var batch []mssmt.BatchedInsertionEntry + for _, tl := range leaves { + batch = append(batch, mssmt.BatchedInsertionEntry{ + Key: tl.key, + Leaf: tl.leaf, + }) + } + + newTree, err := compTree.BatchedInsert(ctx, batch) + require.NoError(t, err) + + // Verify that each inserted leaf can be retrieved. + for _, entry := range batch { + retrieved, err := newTree.Get(ctx, entry.Key) + require.NoError(t, err) + require.Equal(t, entry.Leaf, retrieved, "mismatch for key %x", entry.Key) + } +} + +// TestBatchedInsertEmpty ensures that calling BatchedInsert with an empty batch +// leaves the tree unchanged. +func TestBatchedInsertEmpty(t *testing.T) { + ctx := context.Background() + store := mssmt.NewDefaultStore() + compTree := mssmt.NewCompactedTree(store) + + newTree, err := compTree.BatchedInsert(ctx, []mssmt.BatchedInsertionEntry{}) + require.NoError(t, err) + root, err := newTree.Root(ctx) + require.NoError(t, err) + require.True(t, mssmt.IsEqualNode(root, mssmt.EmptyTree[0])) +} + +// TestBatchedInsertOverflow verifies that a batch insertion causing a sum overflow +// returns an error. +func TestBatchedInsertOverflow(t *testing.T) { + ctx := context.Background() + store := mssmt.NewDefaultStore() + compTree := mssmt.NewCompactedTree(store) + + // Insert one leaf with a huge sum. + huge := uint64(math.MaxUint64 - 100) + key1 := [32]byte{1} + hugeLeaf := mssmt.NewLeafNode([]byte("huge"), huge) + _, err := compTree.Insert(ctx, key1, hugeLeaf) + require.NoError(t, err) + + // Prepare a batch with one normal leaf and one that overflows the root sum. + key2 := [32]byte{2} + normalLeaf := mssmt.NewLeafNode([]byte("normal"), 50) + key3 := [32]byte{3} + overflowLeaf := mssmt.NewLeafNode([]byte("overflow"), 101) // huge + 101 exceeds MaxUint64 + batch := []mssmt.BatchedInsertionEntry{ + {Key: key2, Leaf: normalLeaf}, + {Key: key3, Leaf: overflowLeaf}, + } + _, err = compTree.BatchedInsert(ctx, batch) + require.Error(t, err) + require.ErrorIs(t, err, mssmt.ErrIntegerOverflow) +} + var ( errorTestVectorName = "mssmt_tree_error_cases.json" deletionTestVectorName = "mssmt_tree_deletion.json"