Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(SMT): reverse mutations generation, mutations serialization #355

Open
wants to merge 3 commits into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# TBD

- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).

## 0.13.0 (2024-11-24)

- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
Expand Down
10 changes: 5 additions & 5 deletions src/merkle/smt/full/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl Smt {
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}

Expand Down Expand Up @@ -344,12 +344,12 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
}

fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}

fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}

fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
Expand Down
117 changes: 111 additions & 6 deletions src/merkle/smt/full/tests.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use alloc::vec::Vec;
use std::collections::BTreeMap;

use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
merkle::{
smt::{NodeMutation, SparseMerkleTree},
EmptySubtreeRoots, MerkleStore, MutationSet,
},
utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE,
};

// SMT
// --------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -412,21 +415,49 @@ fn test_prospective_insertion() {

let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
smt.apply_mutations(mutations).unwrap();
let revert = smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
revert.node_mutations,
smt.inner_nodes.iter().map(|(key, _)| (*key, NodeMutation::Removal)).collect(),
"reverse mutations inner nodes did not match"
);

let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = smt.apply_mutations(mutations).unwrap();
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);

// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);

// Test batch updates, and that the order doesn't matter.
let pairs =
Expand All @@ -437,8 +468,16 @@ fn test_prospective_insertion() {
root_empty,
"prospective root for batch removal did not match actual root",
);
smt.apply_mutations(mutations).unwrap();
let old_root = smt.root();
let revert = smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);

let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
Expand All @@ -447,6 +486,72 @@ fn test_prospective_insertion() {
assert_eq!(smt.root(), root_3);
}

#[test]
fn test_mutations_revert() {
let mut smt = Smt::default();

let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];

smt.insert(key_1, value_1);
smt.insert(key_2, value_2);

let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);

let original = smt.clone();

let revert = smt.apply_mutations(mutations).unwrap();
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");

let _ = smt.apply_mutations(revert).unwrap();

assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
}

#[test]
fn test_mutation_set_serialization() {
let mut smt = Smt::default();

let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);

let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];

smt.insert(key_1, value_1);
smt.insert(key_2, value_2);

let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);

let serialized = mutations.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();

assert_eq!(deserialized, mutations, "deserialized mutations did not match original");

let revert = smt.apply_mutations(mutations).unwrap();

let serialized = revert.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();

assert_eq!(deserialized, revert, "deserialized mutations did not match original");
}

/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
Expand Down
111 changes: 102 additions & 9 deletions src/merkle/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;

use num::Integer;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
Expand Down Expand Up @@ -149,7 +150,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, when can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index)
self.remove_inner_node(index);
} else {
self.insert_inner_node(index, InnerNode { left, right });
}
Expand Down Expand Up @@ -256,7 +257,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
}

/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
/// this tree. Return reverse mutation set.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
Expand All @@ -266,7 +267,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
where
Self: Sized,
{
Expand All @@ -287,20 +288,41 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
});
}

let mut reverse_mutations = BTreeMap::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
Removal => {
if let Some(node) = self.remove_inner_node(index) {
reverse_mutations.insert(index, Addition(node));
}
},
Addition(node) => {
if let Some(old_node) = self.insert_inner_node(index, node) {
reverse_mutations.insert(index, Addition(old_node));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we rely on the computation of mutations set, that it didn't generate useless update where new value is the same as previous (actually it does generate unnecessary changes in mutations set). Otherwise we would need to clone node and compare it with old_node.

} else {
reverse_mutations.insert(index, Removal);
}
},
}
}

let mut reverse_pairs = BTreeMap::new();
for (key, value) in new_pairs {
self.insert_value(key, value);
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we also rely on the computation of mutations set, that it didn't generate useless update where new value is the same as previous (actually it does generate unnecessary changes in mutations set). Otherwise we would need to clone value and compare it with old_value.

} else {
reverse_pairs.insert(key, Self::EMPTY_VALUE);
}
}

self.set_root(new_root);

Ok(())
Ok(MutationSet {
old_root: new_root,
node_mutations: reverse_mutations,
new_pairs: reverse_pairs,
new_root: old_root,
})
}

// REQUIRED METHODS
Expand All @@ -326,10 +348,10 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;

/// Inserts an inner node at the given index
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode);
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;

/// Removes an inner node at the given index
fn remove_inner_node(&mut self, index: NodeIndex);
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;

/// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
Expand Down Expand Up @@ -606,8 +628,78 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
}
}

// SERIALIZATION
// ================================================================================================

impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
}
}

impl Deserializable for InnerNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = source.read()?;
let right = source.read()?;

Ok(Self { left, right })
}
}

impl Serializable for NodeMutation {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
match self {
NodeMutation::Removal => target.write_bool(false),
NodeMutation::Addition(inner_node) => {
target.write_bool(true);
inner_node.write_into(target);
},
}
}
}

impl Deserializable for NodeMutation {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
if source.read_bool()? {
let inner_node = source.read()?;
return Ok(NodeMutation::Addition(inner_node));
}

Ok(NodeMutation::Removal)
}
}

impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target);
}
}

impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?;

Ok(Self {
old_root,
node_mutations,
new_pairs,
new_root,
})
}
}

// SUBTREES
// ================================================================================================

/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;

Expand Down Expand Up @@ -797,5 +889,6 @@ pub fn build_subtree_for_bench(

// TESTS
// ================================================================================================

#[cfg(test)]
mod tests;
Loading
Loading