diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp index b58aa1250645..a3e101135878 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp @@ -60,7 +60,7 @@ template class ContentAddressedAppendOn using UnwindBlockCallback = std::function&)>; using FinalizeBlockCallback = EmptyResponseCallback; using GetBlockForIndexCallback = std::function&)>; - using CheckpointCallback = EmptyResponseCallback; + using CheckpointCallback = std::function&)>; using CheckpointCommitCallback = EmptyResponseCallback; using CheckpointRevertCallback = EmptyResponseCallback; @@ -254,8 +254,11 @@ template class ContentAddressedAppendOn void checkpoint(const CheckpointCallback& on_completion); void commit_checkpoint(const CheckpointCommitCallback& on_completion); void revert_checkpoint(const CheckpointRevertCallback& on_completion); - void commit_all_checkpoints(const CheckpointCommitCallback& on_completion); - void revert_all_checkpoints(const CheckpointRevertCallback& on_completion); + void commit_all_checkpoints_to(const CheckpointCommitCallback& on_completion); + void revert_all_checkpoints_to(const CheckpointRevertCallback& on_completion); + void commit_to_depth(uint32_t target_depth, const CheckpointCommitCallback& on_completion); + void revert_to_depth(uint32_t target_depth, const CheckpointRevertCallback& on_completion); + uint32_t checkpoint_depth() const; protected: using ReadTransaction = typename Store::ReadTransaction; @@ -1002,7 +1005,11 @@ void ContentAddressedAppendOnlyTree::rollback(const Rollba template void ContentAddressedAppendOnlyTree::checkpoint(const CheckpointCallback& on_completion) { - auto job = [=, this]() { execute_and_report([=, this]() { store_->checkpoint(); }, on_completion); }; + auto job = [=, this]() { + execute_and_report( + [=, this](TypedResponse& response) { response.inner.depth = store_->checkpoint(); }, + on_completion); + }; workers_->enqueue(job); } @@ -1023,21 +1030,46 @@ void ContentAddressedAppendOnlyTree::revert_checkpoint( } template -void ContentAddressedAppendOnlyTree::commit_all_checkpoints( +void ContentAddressedAppendOnlyTree::commit_all_checkpoints_to( const CheckpointCommitCallback& on_completion) { - auto job = [=, this]() { execute_and_report([=, this]() { store_->commit_all_checkpoints(); }, on_completion); }; + auto job = [=, this]() { execute_and_report([=, this]() { store_->commit_all_checkpoints_to(); }, on_completion); }; workers_->enqueue(job); } template -void ContentAddressedAppendOnlyTree::revert_all_checkpoints( +void ContentAddressedAppendOnlyTree::revert_all_checkpoints_to( const CheckpointRevertCallback& on_completion) { - auto job = [=, this]() { execute_and_report([=, this]() { store_->revert_all_checkpoints(); }, on_completion); }; + auto job = [=, this]() { execute_and_report([=, this]() { store_->revert_all_checkpoints_to(); }, on_completion); }; + workers_->enqueue(job); +} + +template +void ContentAddressedAppendOnlyTree::commit_to_depth( + uint32_t target_depth, const CheckpointCommitCallback& on_completion) +{ + auto job = [=, this]() { + execute_and_report([=, this]() { store_->commit_to_depth(target_depth); }, on_completion); + }; + workers_->enqueue(job); +} + +template +void ContentAddressedAppendOnlyTree::revert_to_depth( + uint32_t target_depth, const CheckpointRevertCallback& on_completion) +{ + auto job = [=, this]() { + execute_and_report([=, this]() { store_->revert_to_depth(target_depth); }, on_completion); + }; workers_->enqueue(job); } +template +uint32_t ContentAddressedAppendOnlyTree::checkpoint_depth() const +{ + return store_->checkpoint_depth(); +} template void ContentAddressedAppendOnlyTree::remove_historic_block( const block_number_t& blockNumber, const RemoveHistoricBlockCallback& on_completion) diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp index cecff513bb46..1518b2d2bfa5 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp @@ -2171,7 +2171,7 @@ TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_checkpoint_and_revert_fo commit_checkpoint_tree(tree, false); } -TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_commit_all_checkpoints) +TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_commit_all_checkpoints_to) { constexpr size_t depth = 10; uint32_t blockSize = 16; @@ -2223,7 +2223,7 @@ TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_commit_all_checkpoints) commit_checkpoint_tree(tree, false); } -TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_revert_all_checkpoints) +TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_revert_all_checkpoints_to) { constexpr size_t depth = 10; uint32_t blockSize = 16; @@ -2274,3 +2274,95 @@ TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_revert_all_checkpoints) revert_checkpoint_tree(tree, false); commit_checkpoint_tree(tree, false); } + +TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_commit_to_depth) +{ + constexpr size_t depth = 10; + uint32_t blockSize = 16; + std::string name = random_string(); + ThreadPoolPtr pool = make_thread_pool(1); + LMDBTreeStore::SharedPtr db = std::make_shared(_directory, name, _mapSize, _maxReaders); + + { + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + std::vector values = create_values(blockSize); + add_values(tree, values); + commit_tree(tree); + } + + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + + // Capture initial state + fr_sibling_path initial_path = get_sibling_path(tree, 0); + + // Depth 1 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + fr_sibling_path after_depth1_path = get_sibling_path(tree, 0); + + // Depth 2 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + + // Depth 3 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + fr_sibling_path after_depth3_path = get_sibling_path(tree, 0); + + // Commit depths 3 and 2 into depth 1, leaving depth at 1 + commit_tree_to_depth(tree, 1); + + // Data from all depths should be present + check_sibling_path(tree, 0, after_depth3_path); + + // Revert depth 1 — should go back to initial state + revert_checkpoint_tree(tree); + check_sibling_path(tree, 0, initial_path); +} + +TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_revert_to_depth) +{ + constexpr size_t depth = 10; + uint32_t blockSize = 16; + std::string name = random_string(); + ThreadPoolPtr pool = make_thread_pool(1); + LMDBTreeStore::SharedPtr db = std::make_shared(_directory, name, _mapSize, _maxReaders); + + { + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + std::vector values = create_values(blockSize); + add_values(tree, values); + commit_tree(tree); + } + + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + + // Depth 1 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + fr_sibling_path after_depth1_path = get_sibling_path(tree, 0); + + // Depth 2 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + + // Depth 3 + checkpoint_tree(tree); + add_values(tree, create_values(blockSize)); + + // Revert depths 3 and 2, leaving depth at 1 + revert_tree_to_depth(tree, 1); + + // Should be back to after depth 1 state + check_sibling_path(tree, 0, after_depth1_path); + + // Depth 1 still active — commit it + commit_checkpoint_tree(tree); + + // Should still have depth 1 data + check_sibling_path(tree, 0, after_depth1_path); +} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp index f1f078e68a1e..2f08ef61ef2b 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp @@ -191,11 +191,14 @@ template class ContentAddressedCachedTreeStore { std::optional find_block_for_index(const index_t& index, ReadTransaction& tx) const; - void checkpoint(); + uint32_t checkpoint(); void revert_checkpoint(); void commit_checkpoint(); - void revert_all_checkpoints(); - void commit_all_checkpoints(); + void revert_all_checkpoints_to(); + void commit_all_checkpoints_to(); + void commit_to_depth(uint32_t depth); + void revert_to_depth(uint32_t depth); + uint32_t checkpoint_depth() const; private: using Cache = ContentAddressedCache; @@ -276,10 +279,10 @@ ContentAddressedCachedTreeStore::ContentAddressedCachedTreeStore( // These checkpoint apis modify the cache's internal state. // They acquire the mutex to prevent races with concurrent read/write operations (e.g., when C++ AVM simulation // runs on a worker thread while TypeScript calls revert_checkpoint from a timeout handler). -template void ContentAddressedCachedTreeStore::checkpoint() +template uint32_t ContentAddressedCachedTreeStore::checkpoint() { std::unique_lock lock(mtx_); - cache_.checkpoint(); + return cache_.checkpoint(); } template void ContentAddressedCachedTreeStore::revert_checkpoint() @@ -294,18 +297,36 @@ template void ContentAddressedCachedTreeStore void ContentAddressedCachedTreeStore::revert_all_checkpoints() +template void ContentAddressedCachedTreeStore::revert_all_checkpoints_to() { std::unique_lock lock(mtx_); cache_.revert_all(); } -template void ContentAddressedCachedTreeStore::commit_all_checkpoints() +template void ContentAddressedCachedTreeStore::commit_all_checkpoints_to() { std::unique_lock lock(mtx_); cache_.commit_all(); } +template void ContentAddressedCachedTreeStore::commit_to_depth(uint32_t depth) +{ + std::unique_lock lock(mtx_); + cache_.commit_to_depth(depth); +} + +template void ContentAddressedCachedTreeStore::revert_to_depth(uint32_t depth) +{ + std::unique_lock lock(mtx_); + cache_.revert_to_depth(depth); +} + +template uint32_t ContentAddressedCachedTreeStore::checkpoint_depth() const +{ + std::unique_lock lock(mtx_); + return cache_.depth(); +} + template index_t ContentAddressedCachedTreeStore::constrain_tree_size_to_only_committed( const RequestContext& requestContext, ReadTransaction& tx) const diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp index e9fa2bfcfaf2..908d3ccde1cb 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp @@ -47,11 +47,14 @@ template class ContentAddressedCache { ContentAddressedCache& operator=(ContentAddressedCache&& other) noexcept = default; bool operator==(const ContentAddressedCache& other) const = default; - void checkpoint(); + uint32_t checkpoint(); void revert(); void commit(); void revert_all(); void commit_all(); + void commit_to_depth(uint32_t depth); + void revert_to_depth(uint32_t depth); + uint32_t depth() const; void reset(uint32_t depth); std::pair find_low_value(const uint256_t& new_leaf_key, @@ -126,9 +129,10 @@ template ContentAddressedCache::ContentA reset(depth); } -template void ContentAddressedCache::checkpoint() +template uint32_t ContentAddressedCache::checkpoint() { journals_.emplace_back(Journal(meta_)); + return static_cast(journals_.size()); } template void ContentAddressedCache::revert() @@ -240,6 +244,31 @@ template void ContentAddressedCache::rev revert(); } } +template uint32_t ContentAddressedCache::depth() const +{ + return static_cast(journals_.size()); +} + +template void ContentAddressedCache::commit_to_depth(uint32_t target_depth) +{ + if (target_depth >= journals_.size()) { + throw std::runtime_error("Invalid depth for commit_to_depth"); + } + while (journals_.size() > target_depth) { + commit(); + } +} + +template void ContentAddressedCache::revert_to_depth(uint32_t target_depth) +{ + if (target_depth >= journals_.size()) { + throw std::runtime_error("Invalid depth for revert_to_depth"); + } + while (journals_.size() > target_depth) { + revert(); + } +} + template void ContentAddressedCache::reset(uint32_t depth) { nodes_ = std::unordered_map(); diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp index 5e6325244a40..e690308530a0 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp @@ -590,3 +590,210 @@ TEST_F(ContentAddressedCacheTest, reverts_remove_all_deeper_commits_2) reverts_remove_all_deeper_commits_2(max_index, depth, num_levels); } } + +TEST_F(ContentAddressedCacheTest, checkpoint_returns_depth) +{ + CacheType cache = create_cache(40); + EXPECT_EQ(cache.depth(), 0u); + EXPECT_EQ(cache.checkpoint(), 1u); + EXPECT_EQ(cache.checkpoint(), 2u); + EXPECT_EQ(cache.checkpoint(), 3u); + EXPECT_EQ(cache.depth(), 3u); +} + +TEST_F(ContentAddressedCacheTest, depth_reports_journal_count) +{ + CacheType cache = create_cache(40); + EXPECT_EQ(cache.depth(), 0u); + cache.checkpoint(); + EXPECT_EQ(cache.depth(), 1u); + cache.checkpoint(); + EXPECT_EQ(cache.depth(), 2u); + cache.commit(); + EXPECT_EQ(cache.depth(), 1u); + cache.revert(); + EXPECT_EQ(cache.depth(), 0u); +} + +TEST_F(ContentAddressedCacheTest, commit_to_depth_partial) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + CacheType original_cache = cache; + + // Depth 1: base checkpoint + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + + // Depth 2 + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + + // Depth 3 + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + + CacheType final_cache = cache; + + // Commit down to depth 1 (commits depths 3 and 2), preserve depth 1 + cache.commit_to_depth(1); + EXPECT_EQ(cache.depth(), 1u); + + // Data from depth 2+3 is merged into depth 1's scope + EXPECT_TRUE(final_cache.is_equivalent_to(cache)); + + // Now revert depth 1 — should go back to original + cache.revert(); + EXPECT_EQ(cache.depth(), 0u); + EXPECT_TRUE(original_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, revert_to_depth_partial) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + + // Depth 1: base checkpoint + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + CacheType after_depth1_cache = cache; + + // Depth 2 + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + + // Depth 3 + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + + // Revert down to depth 1 (reverts depths 3 and 2), preserve depth 1 + cache.revert_to_depth(1); + EXPECT_EQ(cache.depth(), 1u); + + // Data from depth 2+3 is gone, state matches after depth 1 changes + EXPECT_TRUE(after_depth1_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, commit_to_depth_0_is_commit_all) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + CacheType final_cache = cache; + + cache.commit_to_depth(0); + EXPECT_EQ(cache.depth(), 0u); + EXPECT_TRUE(final_cache.is_equivalent_to(cache)); + + // No more operations possible + EXPECT_THROW(cache.commit(), std::runtime_error); + EXPECT_THROW(cache.revert(), std::runtime_error); +} + +TEST_F(ContentAddressedCacheTest, revert_to_depth_0_is_revert_all) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + CacheType original_cache = cache; + + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + + cache.revert_to_depth(0); + EXPECT_EQ(cache.depth(), 0u); + EXPECT_TRUE(original_cache.is_equivalent_to(cache)); + + EXPECT_THROW(cache.commit(), std::runtime_error); + EXPECT_THROW(cache.revert(), std::runtime_error); +} + +TEST_F(ContentAddressedCacheTest, commit_to_depth_at_current_is_single_commit) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + CacheType final_cache = cache; + + // Commit only the top checkpoint (depth 3), leaving depth at 2 + EXPECT_EQ(cache.depth(), 3u); + cache.commit_to_depth(2); + EXPECT_EQ(cache.depth(), 2u); + EXPECT_TRUE(final_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, revert_to_depth_at_current_is_single_revert) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + CacheType after_depth2_cache = cache; + + cache.checkpoint(); + add_to_cache(cache, 300, 100, 1000); + + // Revert only the top checkpoint (depth 3), leaving depth at 2 + EXPECT_EQ(cache.depth(), 3u); + cache.revert_to_depth(2); + EXPECT_EQ(cache.depth(), 2u); + EXPECT_TRUE(after_depth2_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, revert_to_depth_preserves_lower_data) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 100, 1000); + CacheType original_cache = cache; + + // Depth 1 + cache.checkpoint(); + add_to_cache(cache, 100, 100, 1000); + CacheType after_depth1_cache = cache; + + // Depth 2 + cache.checkpoint(); + add_to_cache(cache, 200, 100, 1000); + + // Revert depth 2 only, leaving depth at 1 + EXPECT_EQ(cache.depth(), 2u); + cache.revert_to_depth(1); + EXPECT_EQ(cache.depth(), 1u); + EXPECT_TRUE(after_depth1_cache.is_equivalent_to(cache)); + + // Commit depth 1 — depth 1 data persists + cache.commit(); + EXPECT_EQ(cache.depth(), 0u); + EXPECT_TRUE(after_depth1_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, commit_to_depth_invalid_depth_throws) +{ + CacheType cache = create_cache(40); + cache.checkpoint(); + cache.checkpoint(); + EXPECT_EQ(cache.depth(), 2u); + + // target_depth >= current depth is invalid + EXPECT_THROW(cache.commit_to_depth(2), std::runtime_error); + EXPECT_THROW(cache.commit_to_depth(3), std::runtime_error); + EXPECT_THROW(cache.revert_to_depth(2), std::runtime_error); + EXPECT_THROW(cache.revert_to_depth(3), std::runtime_error); +} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/response.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/response.hpp index c2943f31c2fe..d5461a43adec 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/response.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/response.hpp @@ -32,6 +32,17 @@ struct TreeMetaResponse { TreeMetaResponse& operator=(TreeMetaResponse&& other) noexcept = default; }; +struct CheckpointResponse { + uint32_t depth; + + CheckpointResponse() = default; + ~CheckpointResponse() = default; + CheckpointResponse(const CheckpointResponse& other) = default; + CheckpointResponse(CheckpointResponse&& other) noexcept = default; + CheckpointResponse& operator=(const CheckpointResponse& other) = default; + CheckpointResponse& operator=(CheckpointResponse&& other) noexcept = default; +}; + struct AddDataResponse { index_t size; fr root; diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp index ad736e292900..e7a56a52848f 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp @@ -257,10 +257,18 @@ template void rollback_tree(TreeType& tree) call_operation(completion); } -template void checkpoint_tree(TreeType& tree) +template uint32_t checkpoint_tree(TreeType& tree) { - auto completion = [&](auto completion) { tree.checkpoint(completion); }; - call_operation(completion); + Signal signal; + uint32_t depth = 0; + auto completion = [&](const TypedResponse& response) -> void { + EXPECT_EQ(response.success, true); + depth = response.inner.depth; + signal.signal_level(); + }; + tree.checkpoint(completion); + signal.wait_for_level(); + return depth; } template void commit_checkpoint_tree(TreeType& tree, bool expected_success = true) @@ -279,13 +287,25 @@ template void revert_checkpoint_tree(TreeType& tree, bool ex template void commit_all_tree_checkpoints(TreeType& tree, bool expected_success = true) { - auto completion = [&](auto completion) { tree.commit_all_checkpoints(completion); }; + auto completion = [&](auto completion) { tree.commit_all_checkpoints_to(completion); }; call_operation(completion, expected_success); } template void revert_all_tree_checkpoints(TreeType& tree, bool expected_success = true) { - auto completion = [&](auto completion) { tree.revert_all_checkpoints(completion); }; + auto completion = [&](auto completion) { tree.revert_all_checkpoints_to(completion); }; + call_operation(completion, expected_success); +} + +template void commit_tree_to_depth(TreeType& tree, uint32_t depth, bool expected_success = true) +{ + auto completion = [&](auto completion) { tree.commit_to_depth(depth, completion); }; + call_operation(completion, expected_success); +} + +template void revert_tree_to_depth(TreeType& tree, uint32_t depth, bool expected_success = true) +{ + auto completion = [&](auto completion) { tree.revert_to_depth(depth, completion); }; call_operation(completion, expected_success); } } // namespace bb::crypto::merkle_tree diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp index 57604598396d..2386799b19a0 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp @@ -265,11 +265,11 @@ WorldStateWrapper::WorldStateWrapper(const Napi::CallbackInfo& info) _dispatcher.register_target( WorldStateMessageType::COMMIT_ALL_CHECKPOINTS, - [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return commit_all_checkpoints(obj, buffer); }); + [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return commit_all_checkpoints_to(obj, buffer); }); _dispatcher.register_target( WorldStateMessageType::REVERT_ALL_CHECKPOINTS, - [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return revert_all_checkpoints(obj, buffer); }); + [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return revert_all_checkpoints_to(obj, buffer); }); _dispatcher.register_target( WorldStateMessageType::COPY_STORES, @@ -843,10 +843,12 @@ bool WorldStateWrapper::checkpoint(msgpack::object& obj, msgpack::sbuffer& buffe TypedMessage request; obj.convert(request); - _ws->checkpoint(request.value.forkId); + uint32_t depth = _ws->checkpoint(request.value.forkId); MsgHeader header(request.header.messageId); - messaging::TypedMessage resp_msg(WorldStateMessageType::CREATE_CHECKPOINT, header, {}); + CheckpointDepthResponse resp_value{ depth }; + messaging::TypedMessage resp_msg( + WorldStateMessageType::CREATE_CHECKPOINT, header, resp_value); msgpack::pack(buffer, resp_msg); return true; @@ -880,12 +882,12 @@ bool WorldStateWrapper::revert_checkpoint(msgpack::object& obj, msgpack::sbuffer return true; } -bool WorldStateWrapper::commit_all_checkpoints(msgpack::object& obj, msgpack::sbuffer& buffer) +bool WorldStateWrapper::commit_all_checkpoints_to(msgpack::object& obj, msgpack::sbuffer& buffer) { - TypedMessage request; + TypedMessage request; obj.convert(request); - _ws->commit_all_checkpoints(request.value.forkId); + _ws->commit_all_checkpoints_to(request.value.forkId, request.value.depth); MsgHeader header(request.header.messageId); messaging::TypedMessage resp_msg(WorldStateMessageType::COMMIT_ALL_CHECKPOINTS, header, {}); @@ -894,12 +896,12 @@ bool WorldStateWrapper::commit_all_checkpoints(msgpack::object& obj, msgpack::sb return true; } -bool WorldStateWrapper::revert_all_checkpoints(msgpack::object& obj, msgpack::sbuffer& buffer) +bool WorldStateWrapper::revert_all_checkpoints_to(msgpack::object& obj, msgpack::sbuffer& buffer) { - TypedMessage request; + TypedMessage request; obj.convert(request); - _ws->revert_all_checkpoints(request.value.forkId); + _ws->revert_all_checkpoints_to(request.value.forkId, request.value.depth); MsgHeader header(request.header.messageId); messaging::TypedMessage resp_msg(WorldStateMessageType::REVERT_ALL_CHECKPOINTS, header, {}); diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp index 02945f8899a9..cd4f0d02e8e1 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp @@ -75,8 +75,8 @@ class WorldStateWrapper : public Napi::ObjectWrap { bool checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer); bool commit_checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer); bool revert_checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer); - bool commit_all_checkpoints(msgpack::object& obj, msgpack::sbuffer& buffer); - bool revert_all_checkpoints(msgpack::object& obj, msgpack::sbuffer& buffer); + bool commit_all_checkpoints_to(msgpack::object& obj, msgpack::sbuffer& buffer); + bool revert_all_checkpoints_to(msgpack::object& obj, msgpack::sbuffer& buffer); bool copy_stores(msgpack::object& obj, msgpack::sbuffer& buffer); }; diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp index 042537458d5d..8f6b481ad41a 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp @@ -88,6 +88,17 @@ struct ForkIdOnlyRequest { SERIALIZATION_FIELDS(forkId); }; +struct ForkIdWithDepthRequest { + uint64_t forkId; + uint32_t depth; + SERIALIZATION_FIELDS(forkId, depth); +}; + +struct CheckpointDepthResponse { + uint32_t depth; + SERIALIZATION_FIELDS(depth); +}; + struct TreeIdAndRevisionRequest { MerkleTreeId treeId; WorldStateRevision revision; diff --git a/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp b/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp index 93ab14689978..f221e93fcf6f 100644 --- a/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp +++ b/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp @@ -1062,16 +1062,16 @@ bool WorldState::determine_if_synched(std::array& metaRespo return true; } -void WorldState::checkpoint(const uint64_t& forkId) +uint32_t WorldState::checkpoint(const uint64_t& forkId) { Fork::SharedPtr fork = retrieve_fork(forkId); Signal signal(static_cast(fork->_trees.size())); - std::array local; + std::array, NUM_TREES> local; std::mutex mtx; for (auto& [id, tree] : fork->_trees) { std::visit( [&signal, &local, id, &mtx](auto&& wrapper) { - wrapper.tree->checkpoint([&signal, &local, &mtx, id](Response& resp) { + wrapper.tree->checkpoint([&signal, &local, &mtx, id](TypedResponse& resp) { { std::lock_guard lock(mtx); local[id] = std::move(resp); @@ -1087,6 +1087,8 @@ void WorldState::checkpoint(const uint64_t& forkId) throw std::runtime_error(m.message); } } + // All trees have the same checkpoint depth; return it from the first tree's response + return local[0].inner.depth; } void WorldState::commit_checkpoint(const uint64_t& forkId) @@ -1143,7 +1145,7 @@ void WorldState::revert_checkpoint(const uint64_t& forkId) } } -void WorldState::commit_all_checkpoints(const uint64_t& forkId) +void WorldState::commit_all_checkpoints_to(const uint64_t& forkId, uint32_t depth) { Fork::SharedPtr fork = retrieve_fork(forkId); Signal signal(static_cast(fork->_trees.size())); @@ -1151,14 +1153,15 @@ void WorldState::commit_all_checkpoints(const uint64_t& forkId) std::mutex mtx; for (auto& [id, tree] : fork->_trees) { std::visit( - [&signal, &local, id, &mtx](auto&& wrapper) { - wrapper.tree->commit_all_checkpoints([&signal, &local, &mtx, id](Response& resp) { + [&signal, &local, id, &mtx, depth](auto&& wrapper) { + auto callback = [&signal, &local, &mtx, id](Response& resp) { { std::lock_guard lock(mtx); local[id] = std::move(resp); } signal.signal_decrement(); - }); + }; + wrapper.tree->commit_to_depth(depth, callback); }, tree); } @@ -1170,7 +1173,7 @@ void WorldState::commit_all_checkpoints(const uint64_t& forkId) } } -void WorldState::revert_all_checkpoints(const uint64_t& forkId) +void WorldState::revert_all_checkpoints_to(const uint64_t& forkId, uint32_t depth) { Fork::SharedPtr fork = retrieve_fork(forkId); Signal signal(static_cast(fork->_trees.size())); @@ -1178,14 +1181,15 @@ void WorldState::revert_all_checkpoints(const uint64_t& forkId) std::mutex mtx; for (auto& [id, tree] : fork->_trees) { std::visit( - [&signal, &local, id, &mtx](auto&& wrapper) { - wrapper.tree->revert_all_checkpoints([&signal, &local, &mtx, id](Response& resp) { + [&signal, &local, id, &mtx, depth](auto&& wrapper) { + auto callback = [&signal, &local, &mtx, id](Response& resp) { { std::lock_guard lock(mtx); local[id] = std::move(resp); } signal.signal_decrement(); - }); + }; + wrapper.tree->revert_to_depth(depth, callback); }, tree); } diff --git a/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp b/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp index 256c2950600b..d7d8f99d46f5 100644 --- a/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp +++ b/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp @@ -287,11 +287,11 @@ class WorldState { const std::vector& nullifiers, const std::vector& public_writes); - void checkpoint(const uint64_t& forkId); + uint32_t checkpoint(const uint64_t& forkId); void commit_checkpoint(const uint64_t& forkId); void revert_checkpoint(const uint64_t& forkId); - void commit_all_checkpoints(const uint64_t& forkId); - void revert_all_checkpoints(const uint64_t& forkId); + void commit_all_checkpoints_to(const uint64_t& forkId, uint32_t depth); + void revert_all_checkpoints_to(const uint64_t& forkId, uint32_t depth); private: std::shared_ptr _workers; diff --git a/yarn-project/simulator/src/public/hinting_db_sources.ts b/yarn-project/simulator/src/public/hinting_db_sources.ts index 79044c631e64..85f8ab422ccf 100644 --- a/yarn-project/simulator/src/public/hinting_db_sources.ts +++ b/yarn-project/simulator/src/public/hinting_db_sources.ts @@ -410,12 +410,12 @@ export class HintingMerkleWriteOperations implements MerkleTreeWriteOperations { } } - public async createCheckpoint(): Promise { + public async createCheckpoint(): Promise { const actionCounter = this.checkpointActionCounter++; const oldCheckpointId = this.getCurrentCheckpointId(); const treesStateHash = await this.getTreesStateHash(); - await this.db.createCheckpoint(); + const depth = await this.db.createCheckpoint(); this.checkpointStack.push(this.nextCheckpointId++); const newCheckpointId = this.getCurrentCheckpointId(); @@ -424,14 +424,16 @@ export class HintingMerkleWriteOperations implements MerkleTreeWriteOperations { HintingMerkleWriteOperations.log.trace( `[createCheckpoint:${actionCounter}] Checkpoint evolved ${oldCheckpointId} -> ${newCheckpointId} at trees state ${treesStateHash}.`, ); + + return depth; } - public commitAllCheckpoints(): Promise { - throw new Error('commitAllCheckpoints is not supported in HintingMerkleWriteOperations.'); + public commitAllCheckpointsTo(_depth: number): Promise { + throw new Error('commitAllCheckpointsTo is not supported in HintingMerkleWriteOperations.'); } - public revertAllCheckpoints(): Promise { - throw new Error('revertAllCheckpoints is not supported in HintingMerkleWriteOperations.'); + public revertAllCheckpointsTo(_depth: number): Promise { + throw new Error('revertAllCheckpointsTo is not supported in HintingMerkleWriteOperations.'); } public async commitCheckpoint(): Promise { diff --git a/yarn-project/simulator/src/public/public_processor/apps_tests/timeout_race.test.ts b/yarn-project/simulator/src/public/public_processor/apps_tests/timeout_race.test.ts index 3d06b2323916..2d16e26e602f 100644 --- a/yarn-project/simulator/src/public/public_processor/apps_tests/timeout_race.test.ts +++ b/yarn-project/simulator/src/public/public_processor/apps_tests/timeout_race.test.ts @@ -20,7 +20,7 @@ import { GasFees } from '@aztec/stdlib/gas'; import { MerkleTreeId, merkleTreeIds } from '@aztec/stdlib/trees'; import { GlobalVariables } from '@aztec/stdlib/tx'; import { getTelemetryClient } from '@aztec/telemetry-client'; -import { NativeWorldStateService } from '@aztec/world-state'; +import { ForkCheckpoint, NativeWorldStateService } from '@aztec/world-state'; import { jest } from '@jest/globals'; @@ -115,7 +115,7 @@ describe('PublicProcessor C++ Timeout Race Condition', () => { } // Create checkpoint BEFORE simulation (like PublicProcessor does) - await merkleTrees.createCheckpoint(); + const forkCheckpoint = await ForkCheckpoint.new(merkleTrees); // Create transaction that calls the spammer contract const tx = await tester.createTx(admin, [], [{ address: contractAddress, args: callArgs }]); @@ -136,11 +136,8 @@ describe('PublicProcessor C++ Timeout Race Condition', () => { } // BUG - No cancel, C++ continues running during reverts below - // Revert checkpoint - await merkleTrees.revertCheckpoint(); - - // Clean up - await merkleTrees.revertAllCheckpoints(); + // Clean up - revert all changes + await forkCheckpoint.revertToCheckpoint(); // Wait for simulation promise for cleanup await Promise.race([simulationPromise.catch(() => {}), sleep(100)]); diff --git a/yarn-project/simulator/src/public/public_processor/guarded_merkle_tree.ts b/yarn-project/simulator/src/public/public_processor/guarded_merkle_tree.ts index bcbd818a03f0..71133c4a2ebf 100644 --- a/yarn-project/simulator/src/public/public_processor/guarded_merkle_tree.ts +++ b/yarn-project/simulator/src/public/public_processor/guarded_merkle_tree.ts @@ -134,7 +134,7 @@ export class GuardedMerkleTreeOperations implements MerkleTreeWriteOperations { ): Promise<(BlockNumber | undefined)[]> { return this.guardAndPush(() => this.target.getBlockNumbersForLeafIndices(treeId, leafIndices)); } - createCheckpoint(): Promise { + createCheckpoint(): Promise { return this.guardAndPush(() => this.target.createCheckpoint()); } commitCheckpoint(): Promise { @@ -143,11 +143,11 @@ export class GuardedMerkleTreeOperations implements MerkleTreeWriteOperations { revertCheckpoint(): Promise { return this.guardAndPush(() => this.target.revertCheckpoint()); } - commitAllCheckpoints(): Promise { - return this.guardAndPush(() => this.target.commitAllCheckpoints()); + commitAllCheckpointsTo(depth: number): Promise { + return this.guardAndPush(() => this.target.commitAllCheckpointsTo(depth)); } - revertAllCheckpoints(): Promise { - return this.guardAndPush(() => this.target.revertAllCheckpoints()); + revertAllCheckpointsTo(depth: number): Promise { + return this.guardAndPush(() => this.target.revertAllCheckpointsTo(depth)); } findSiblingPaths( treeId: ID, diff --git a/yarn-project/simulator/src/public/public_processor/public_processor.test.ts b/yarn-project/simulator/src/public/public_processor/public_processor.test.ts index 907ee1f907c6..23a019bb6080 100644 --- a/yarn-project/simulator/src/public/public_processor/public_processor.test.ts +++ b/yarn-project/simulator/src/public/public_processor/public_processor.test.ts @@ -91,6 +91,7 @@ describe('public_processor', () => { new PublicDataTreeLeafPreimage(new PublicDataTreeLeaf(Fr.ZERO, Fr.ZERO), /*nextKey=*/ Fr.ZERO, /*nextIndex=*/ 0n), ); merkleTree.getStateReference.mockResolvedValue(stateReference); + merkleTree.createCheckpoint.mockResolvedValue(1); publicTxSimulator.simulate.mockImplementation(() => { return Promise.resolve(mockedEnqueuedCallsResult); @@ -158,7 +159,7 @@ describe('public_processor', () => { expect(failed[0].error).toEqual(new Error(`Failed`)); expect(merkleTree.commitCheckpoint).toHaveBeenCalledTimes(0); - expect(merkleTree.revertCheckpoint).toHaveBeenCalledTimes(1); + expect(merkleTree.revertAllCheckpointsTo).toHaveBeenCalledWith(0); }); it('if a tx errors with assertion failure, public processor returns failed tx with its assertion message', async function () { @@ -173,7 +174,7 @@ describe('public_processor', () => { expect(failed[0].error.message).toMatch(/Forced assertion failure/); expect(merkleTree.commitCheckpoint).toHaveBeenCalledTimes(0); - expect(merkleTree.revertCheckpoint).toHaveBeenCalledTimes(1); + expect(merkleTree.revertAllCheckpointsTo).toHaveBeenCalledWith(0); }); it('does not attempt to overfill a block', async function () { @@ -314,11 +315,45 @@ describe('public_processor', () => { expect(failed[0].error.message).toMatch(/Not enough balance/i); expect(merkleTree.commitCheckpoint).toHaveBeenCalledTimes(0); - expect(merkleTree.revertCheckpoint).toHaveBeenCalledTimes(1); + expect(merkleTree.revertAllCheckpointsTo).toHaveBeenCalledWith(0); expect(merkleTree.sequentialInsert).toHaveBeenCalledTimes(0); }); }); + describe('checkpoint depth', () => { + it('calls revertAllCheckpointsTo with depth on tx failure', async function () { + merkleTree.createCheckpoint.mockResolvedValue(2); + publicTxSimulator.simulate.mockRejectedValue(new Error('Boom')); + + const tx = await mockTxWithPublicCalls(); + const [processed, failed] = await processor.process([tx]); + + expect(processed).toEqual([]); + expect(failed).toHaveLength(1); + expect(merkleTree.revertAllCheckpointsTo).toHaveBeenCalledWith(1); + expect(merkleTree.commitCheckpoint).not.toHaveBeenCalled(); + }); + + it('createCheckpoint is called for each tx', async function () { + const txs = await timesParallel(3, () => mockPrivateOnlyTx()); + + await processor.process(txs); + + expect(merkleTree.createCheckpoint).toHaveBeenCalledTimes(3); + }); + + it('commits checkpoint on successful tx', async function () { + const tx = await mockTxWithPublicCalls(); + + const [processed, failed] = await processor.process([tx]); + + expect(processed).toHaveLength(1); + expect(failed).toEqual([]); + expect(merkleTree.commitCheckpoint).toHaveBeenCalledTimes(1); + expect(merkleTree.revertAllCheckpointsTo).not.toHaveBeenCalled(); + }); + }); + // on uncaught error, public processor clears the tx-level cache entirely it('clears the tx-level cache entirely on uncaught error (like SETUP failure)', async function () { const tx = await mockTxWithPublicCalls(); diff --git a/yarn-project/simulator/src/public/public_processor/public_processor.ts b/yarn-project/simulator/src/public/public_processor/public_processor.ts index 45a3d9e6906e..20ce6fbaa3e4 100644 --- a/yarn-project/simulator/src/public/public_processor/public_processor.ts +++ b/yarn-project/simulator/src/public/public_processor/public_processor.ts @@ -325,14 +325,10 @@ export class PublicProcessor implements Traceable { // 1. At least one outstanding checkpoint that has not been committed (the one created before we processed the tx). // 2. Possible state updates on that checkpoint or any others created during execution. - // First we revert a checkpoint as managed by the ForkCheckpoint. This will revert whatever is the current checkpoint - // which may not be the one originally created by this object. But that is ok, we do this to fulfil the ForkCheckpoint - // lifecycle expectations and ensure it doesn't attempt to commit later on. - await checkpoint.revert(); - - // Now we want to revert any/all remaining checkpoints, destroying any outstanding state updates. - // This needs to be done directly on the underlying fork as the guarded fork has been stopped. - await this.guardedMerkleTree.getUnderlyingFork().revertAllCheckpoints(); + // Revert all checkpoints at or above this checkpoint's depth (inclusive), destroying any outstanding state + // updates from this tx and any nested checkpoints created during execution. This preserves any checkpoints + // created by callers below our depth. + await checkpoint.revertToCheckpoint(); // Revert any contracts added to the DB for the tx. this.contractsDB.revertCheckpoint(); @@ -344,9 +340,9 @@ export class PublicProcessor implements Traceable { break; } - // Roll back state to start of TX before proceeding to next TX - await checkpoint.revert(); - await this.guardedMerkleTree.getUnderlyingFork().revertAllCheckpoints(); + // Roll back state to start of TX before proceeding to next TX. + // Reverts all checkpoints at or above this checkpoint's depth, preserving any caller checkpoints below. + await checkpoint.revertToCheckpoint(); this.contractsDB.revertCheckpoint(); const errorMessage = err instanceof Error || err instanceof AssertionError ? err.message : 'Unknown error'; this.log.warn(`Failed to process tx ${txHash.toString()}: ${errorMessage} ${err?.stack}`); diff --git a/yarn-project/stdlib/src/interfaces/merkle_tree_operations.ts b/yarn-project/stdlib/src/interfaces/merkle_tree_operations.ts index 63ee8e82f9b1..29625e9d4c43 100644 --- a/yarn-project/stdlib/src/interfaces/merkle_tree_operations.ts +++ b/yarn-project/stdlib/src/interfaces/merkle_tree_operations.ts @@ -225,30 +225,20 @@ export interface MerkleTreeReadOperations { } export interface MerkleTreeCheckpointOperations { - /** - * Checkpoints the current fork state - */ - createCheckpoint(): Promise; + /** Checkpoints the current fork state. Returns the depth of the new checkpoint. */ + createCheckpoint(): Promise; - /** - * Commits the current checkpoint - */ + /** Commits the current checkpoint. */ commitCheckpoint(): Promise; - /** - * Reverts the current checkpoint - */ + /** Reverts the current checkpoint. */ revertCheckpoint(): Promise; - /** - * Commits all checkpoints - */ - commitAllCheckpoints(): Promise; + /** Commits all checkpoints above the given depth, leaving checkpoint depth at the given value. */ + commitAllCheckpointsTo(depth: number): Promise; - /** - * Reverts all checkpoints - */ - revertAllCheckpoints(): Promise; + /** Reverts all checkpoints above the given depth, leaving checkpoint depth at the given value. */ + revertAllCheckpointsTo(depth: number): Promise; } export interface MerkleTreeWriteOperations diff --git a/yarn-project/world-state/src/native/fork_checkpoint.test.ts b/yarn-project/world-state/src/native/fork_checkpoint.test.ts new file mode 100644 index 000000000000..787ccfab1221 --- /dev/null +++ b/yarn-project/world-state/src/native/fork_checkpoint.test.ts @@ -0,0 +1,71 @@ +import type { MerkleTreeCheckpointOperations } from '@aztec/stdlib/interfaces/server'; + +import { type MockProxy, mock } from 'jest-mock-extended'; + +import { ForkCheckpoint } from './fork_checkpoint.js'; + +describe('ForkCheckpoint', () => { + let fork: MockProxy; + + beforeEach(() => { + fork = mock(); + fork.createCheckpoint.mockResolvedValue(5); + fork.commitCheckpoint.mockResolvedValue(); + fork.revertCheckpoint.mockResolvedValue(); + }); + + it('stores depth from createCheckpoint', async () => { + const checkpoint = await ForkCheckpoint.new(fork); + expect(checkpoint.depth).toBe(5); + expect(fork.createCheckpoint).toHaveBeenCalledTimes(1); + }); + + it('commit calls commitCheckpoint on fork', async () => { + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.commit(); + expect(fork.commitCheckpoint).toHaveBeenCalledTimes(1); + }); + + it('revert calls revertCheckpoint on fork', async () => { + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.revert(); + expect(fork.revertCheckpoint).toHaveBeenCalledTimes(1); + }); + + it('revertToCheckpoint calls revertAllCheckpointsTo with depth', async () => { + fork.revertAllCheckpointsTo.mockResolvedValue(); + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.revertToCheckpoint(); + expect(fork.revertAllCheckpointsTo).toHaveBeenCalledWith(4); + }); + + it('revertToCheckpoint prevents subsequent commit', async () => { + fork.revertAllCheckpointsTo.mockResolvedValue(); + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.revertToCheckpoint(); + await checkpoint.commit(); + expect(fork.commitCheckpoint).not.toHaveBeenCalled(); + }); + + it('revertToCheckpoint is idempotent', async () => { + fork.revertAllCheckpointsTo.mockResolvedValue(); + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.revertToCheckpoint(); + await checkpoint.revertToCheckpoint(); + expect(fork.revertAllCheckpointsTo).toHaveBeenCalledTimes(1); + }); + + it('commit is idempotent', async () => { + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.commit(); + await checkpoint.commit(); + expect(fork.commitCheckpoint).toHaveBeenCalledTimes(1); + }); + + it('revert is idempotent', async () => { + const checkpoint = await ForkCheckpoint.new(fork); + await checkpoint.revert(); + await checkpoint.revert(); + expect(fork.revertCheckpoint).toHaveBeenCalledTimes(1); + }); +}); diff --git a/yarn-project/world-state/src/native/fork_checkpoint.ts b/yarn-project/world-state/src/native/fork_checkpoint.ts index c4172d689fc3..1672092ff0fe 100644 --- a/yarn-project/world-state/src/native/fork_checkpoint.ts +++ b/yarn-project/world-state/src/native/fork_checkpoint.ts @@ -3,11 +3,14 @@ import type { MerkleTreeCheckpointOperations } from '@aztec/stdlib/interfaces/se export class ForkCheckpoint { private completed = false; - private constructor(private readonly fork: MerkleTreeCheckpointOperations) {} + private constructor( + private readonly fork: MerkleTreeCheckpointOperations, + public readonly depth: number, + ) {} static async new(fork: MerkleTreeCheckpointOperations): Promise { - await fork.createCheckpoint(); - return new ForkCheckpoint(fork); + const depth = await fork.createCheckpoint(); + return new ForkCheckpoint(fork, depth); } async commit(): Promise { @@ -27,4 +30,17 @@ export class ForkCheckpoint { await this.fork.revertCheckpoint(); this.completed = true; } + + /** + * Reverts this checkpoint and any nested checkpoints created on top of it, + * leaving the checkpoint depth at the level it was before this checkpoint was created. + */ + async revertToCheckpoint(): Promise { + if (this.completed) { + return; + } + + await this.fork.revertAllCheckpointsTo(this.depth - 1); + this.completed = true; + } } diff --git a/yarn-project/world-state/src/native/merkle_trees_facade.ts b/yarn-project/world-state/src/native/merkle_trees_facade.ts index b7a107a8eb80..b8d4ca92b3e0 100644 --- a/yarn-project/world-state/src/native/merkle_trees_facade.ts +++ b/yarn-project/world-state/src/native/merkle_trees_facade.ts @@ -319,9 +319,10 @@ export class MerkleTreesForkFacade extends MerkleTreesFacade implements MerkleTr } } - public async createCheckpoint(): Promise { + public async createCheckpoint(): Promise { assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); - await this.instance.call(WorldStateMessageType.CREATE_CHECKPOINT, { forkId: this.revision.forkId }); + const resp = await this.instance.call(WorldStateMessageType.CREATE_CHECKPOINT, { forkId: this.revision.forkId }); + return resp.depth; } public async commitCheckpoint(): Promise { @@ -334,14 +335,20 @@ export class MerkleTreesForkFacade extends MerkleTreesFacade implements MerkleTr await this.instance.call(WorldStateMessageType.REVERT_CHECKPOINT, { forkId: this.revision.forkId }); } - public async commitAllCheckpoints(): Promise { + public async commitAllCheckpointsTo(depth: number): Promise { assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); - await this.instance.call(WorldStateMessageType.COMMIT_ALL_CHECKPOINTS, { forkId: this.revision.forkId }); + await this.instance.call(WorldStateMessageType.COMMIT_ALL_CHECKPOINTS, { + forkId: this.revision.forkId, + depth, + }); } - public async revertAllCheckpoints(): Promise { + public async revertAllCheckpointsTo(depth: number): Promise { assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); - await this.instance.call(WorldStateMessageType.REVERT_ALL_CHECKPOINTS, { forkId: this.revision.forkId }); + await this.instance.call(WorldStateMessageType.REVERT_ALL_CHECKPOINTS, { + forkId: this.revision.forkId, + depth, + }); } } diff --git a/yarn-project/world-state/src/native/message.ts b/yarn-project/world-state/src/native/message.ts index 64f195918c32..edceed40e4b3 100644 --- a/yarn-project/world-state/src/native/message.ts +++ b/yarn-project/world-state/src/native/message.ts @@ -284,6 +284,16 @@ interface WithForkId { forkId: number; } +interface CreateCheckpointResponse { + depth: number; +} + +/** Request to commit/revert all checkpoints down to a target depth. The resulting depth after the operation equals the given depth. */ +interface CheckpointDepthRequest extends WithForkId { + /** The target depth after the operation. All checkpoints above this depth are committed/reverted. */ + depth: number; +} + interface WithWorldStateRevision { revision: WorldStateRevision; } @@ -487,8 +497,8 @@ export type WorldStateRequest = { [WorldStateMessageType.CREATE_CHECKPOINT]: WithForkId; [WorldStateMessageType.COMMIT_CHECKPOINT]: WithForkId; [WorldStateMessageType.REVERT_CHECKPOINT]: WithForkId; - [WorldStateMessageType.COMMIT_ALL_CHECKPOINTS]: WithForkId; - [WorldStateMessageType.REVERT_ALL_CHECKPOINTS]: WithForkId; + [WorldStateMessageType.COMMIT_ALL_CHECKPOINTS]: CheckpointDepthRequest; + [WorldStateMessageType.REVERT_ALL_CHECKPOINTS]: CheckpointDepthRequest; [WorldStateMessageType.COPY_STORES]: CopyStoresRequest; @@ -529,7 +539,7 @@ export type WorldStateResponse = { [WorldStateMessageType.GET_STATUS]: WorldStateStatusSummary; - [WorldStateMessageType.CREATE_CHECKPOINT]: void; + [WorldStateMessageType.CREATE_CHECKPOINT]: CreateCheckpointResponse; [WorldStateMessageType.COMMIT_CHECKPOINT]: void; [WorldStateMessageType.REVERT_CHECKPOINT]: void; [WorldStateMessageType.COMMIT_ALL_CHECKPOINTS]: void; diff --git a/yarn-project/world-state/src/native/native_world_state.test.ts b/yarn-project/world-state/src/native/native_world_state.test.ts index 9677c8698098..ea52aa4a20b3 100644 --- a/yarn-project/world-state/src/native/native_world_state.test.ts +++ b/yarn-project/world-state/src/native/native_world_state.test.ts @@ -1578,7 +1578,8 @@ describe('NativeWorldState', () => { const fork = await ws.fork(); await advanceState(fork); const siblingPathsBefore = await getSiblingPaths(fork); - await fork.createCheckpoint(); + const checkpointDepth = await fork.createCheckpoint(); + expect(checkpointDepth).toEqual(1); await compareState(fork, siblingPathsBefore, true); @@ -1593,7 +1594,7 @@ describe('NativeWorldState', () => { await compareState(fork, siblingPathsAfter, true); await compareState(fork, siblingPathsBefore, false); - await fork.commitAllCheckpoints(); + await fork.commitAllCheckpointsTo(checkpointDepth - 1); await compareState(fork, siblingPathsAfter, true); await compareState(fork, siblingPathsBefore, false); @@ -1604,7 +1605,8 @@ describe('NativeWorldState', () => { const fork = await ws.fork(); await advanceState(fork); const siblingPathsBefore = await getSiblingPaths(fork); - await fork.createCheckpoint(); + const checkpointDepth = await fork.createCheckpoint(); + expect(checkpointDepth).toEqual(1); await compareState(fork, siblingPathsBefore, true); @@ -1612,14 +1614,15 @@ describe('NativeWorldState', () => { let siblingPathsAfter: SiblingPath[] = []; for (let i = 0; i < numCommits; i++) { - await fork.createCheckpoint(); + const newCheckpointDepth = await fork.createCheckpoint(); + expect(newCheckpointDepth).toEqual(checkpointDepth + i + 1); siblingPathsAfter = await advanceState(fork); } await compareState(fork, siblingPathsAfter, true); await compareState(fork, siblingPathsBefore, false); - await fork.revertAllCheckpoints(); + await fork.revertAllCheckpointsTo(checkpointDepth - 1); await compareState(fork, siblingPathsAfter, false); await compareState(fork, siblingPathsBefore, true); @@ -1835,5 +1838,161 @@ describe('NativeWorldState', () => { await fork.close(); }); + + it('createCheckpoint returns depth', async () => { + const fork = await ws.fork(); + expect(await fork.createCheckpoint()).toBe(1); + expect(await fork.createCheckpoint()).toBe(2); + expect(await fork.createCheckpoint()).toBe(3); + await fork.close(); + }); + + it('can commit all to depth', async () => { + const fork = await ws.fork(); + + // Create 3 checkpoints with state changes between each + const initialPaths = await getSiblingPaths(fork); + + await fork.createCheckpoint(); // depth 1 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 2 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 3 + const afterDepth3Paths = await advanceState(fork); + + // Commit depths 3 and 2 into depth 1, leaving depth at 1 + await fork.commitAllCheckpointsTo(1); + + // State should reflect all changes + await compareState(fork, afterDepth3Paths, true); + + // Revert depth 1 — should go back to initial state + await fork.revertCheckpoint(); + await compareState(fork, initialPaths, true); + + await fork.close(); + }); + + it('can revert all to depth', async () => { + const fork = await ws.fork(); + + await fork.createCheckpoint(); // depth 1 + const afterDepth1Paths = await advanceState(fork); + + await fork.createCheckpoint(); // depth 2 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 3 + await advanceState(fork); + + // Revert depths 3 and 2, leaving depth at 1 + await fork.revertAllCheckpointsTo(1); + + // Should be back to after depth 1 state + await compareState(fork, afterDepth1Paths, true); + + // Depth 1 still active — commit it + await fork.commitCheckpoint(); + await compareState(fork, afterDepth1Paths, true); + + await fork.close(); + }); + + it('revert to depth preserves lower checkpoints', async () => { + const fork = await ws.fork(); + + await fork.createCheckpoint(); // depth 1 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 2 + await advanceState(fork); + + // Revert depth 2 only, leaving depth at 1 + await fork.revertAllCheckpointsTo(1); + + // Create new checkpoint at depth 2 with different changes + await fork.createCheckpoint(); // depth 2 again + const newDepth2Paths = await advanceState(fork); + + // Commit depth 2 + await fork.commitCheckpoint(); + + // Commit depth 1 + await fork.commitCheckpoint(); + + // Final state should include the new depth 2 changes + await compareState(fork, newDepth2Paths, true); + + await fork.close(); + }); + + it('commit all with depth 0 commits everything', async () => { + const fork = await ws.fork(); + + await fork.createCheckpoint(); // depth 1 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 2 + const finalPaths = await advanceState(fork); + + // depth 0 commits all checkpoints + await fork.commitAllCheckpointsTo(0); + + // State should reflect all changes + await compareState(fork, finalPaths, true); + + await fork.close(); + }); + + it('revert all with depth 0 reverts everything', async () => { + const fork = await ws.fork(); + const initialPaths = await getSiblingPaths(fork); + + await fork.createCheckpoint(); // depth 1 + await advanceState(fork); + + await fork.createCheckpoint(); // depth 2 + await advanceState(fork); + + // depth 0 reverts all checkpoints + await fork.revertAllCheckpointsTo(0); + + // Should be back to initial state + await compareState(fork, initialPaths, true); + + await fork.close(); + }); + + it('depth is consistent across multiple checkpoint cycles', async () => { + const fork = await ws.fork(); + + // Create checkpoint depth 1 + expect(await fork.createCheckpoint()).toBe(1); + const afterDepth1Paths = await advanceState(fork); + + // Create checkpoint depth 2 + expect(await fork.createCheckpoint()).toBe(2); + await advanceState(fork); + + // Revert depth 2, leaving depth at 1 + await fork.revertAllCheckpointsTo(1); + await compareState(fork, afterDepth1Paths, true); + + // Create new depth 2 + expect(await fork.createCheckpoint()).toBe(2); + const newDepth2Paths = await advanceState(fork); + + // Commit depth 2 + await fork.commitCheckpoint(); + await compareState(fork, newDepth2Paths, true); + + // Commit depth 1 + await fork.commitCheckpoint(); + await compareState(fork, newDepth2Paths, true); + + await fork.close(); + }); }); });