Skip to content

Commit

Permalink
Permit child MastNodeIds to exceed the MastNodeIds of their paren…
Browse files Browse the repository at this point in the history
…ts (#1542)

* feat(core): Permit child `MastNodeId`s to exceed parent ids

* chore: Add changelog

* chore(core): Add doc comments for `node_count`
  • Loading branch information
PhilippGackstatter authored Nov 1, 2024
1 parent 2b96b85 commit f1c0553
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 20 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Rename `EqHash` to `MastNodeFingerprint` and make it `pub` (#1539)
- [BREAKING] `DYN` operation now expects a memory address pointing to the procedure hash (#1535)
- [BREAKING] `DYNCALL` operation fixed, and now expects a memory address pointing to the procedure hash (#1535)

- Permit child `MastNodeId`s to exceed the `MastNodeId`s of their parents (#1542)

#### Fixes

Expand Down
37 changes: 28 additions & 9 deletions core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,22 +525,41 @@ impl MastNodeId {
value: u32,
mast_forest: &MastForest,
) -> Result<Self, DeserializationError> {
if (value as usize) < mast_forest.nodes.len() {
Ok(Self(value))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but only {} nodes in the forest",
value,
mast_forest.nodes.len(),
)))
}
Self::from_u32_with_node_count(value, mast_forest.nodes.len())
}

/// Returns a new [`MastNodeId`] from the given `value` without checking its validity.
pub(crate) fn new_unchecked(value: u32) -> Self {
Self(value)
}

/// Returns a new [`MastNodeId`] with the provided `id`, or an error if `id` is greater or equal
/// to `node_count`. The `node_count` is the total number of nodes in the [`MastForest`] for
/// which this ID is being constructed.
///
/// This function can be used when deserializing an id whose corresponding node is not yet in
/// the forest and [`Self::from_u32_safe`] would fail. For instance, when deserializing the ids
/// referenced by the Join node in this forest:
///
/// ```text
/// [Join(1, 2), Block(foo), Block(bar)]
/// ```
///
/// Since it is less safe than [`Self::from_u32_safe`] and usually not needed it is not public.
pub(super) fn from_u32_with_node_count(
id: u32,
node_count: usize,
) -> Result<Self, DeserializationError> {
if (id as usize) < node_count {
Ok(Self(id))
} else {
Err(DeserializationError::InvalidValue(format!(
"Invalid deserialized MAST node ID '{}', but {} is the number of nodes in the forest",
id, node_count,
)))
}
}

pub fn as_usize(&self) -> usize {
self.0 as usize
}
Expand Down
21 changes: 13 additions & 8 deletions core/src/mast/serialization/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,14 @@ impl MastNodeInfo {
Self { ty, digest: mast_node.digest() }
}

/// Attempts to convert this [`MastNodeInfo`] into a [`MastNode`] for the given `mast_forest`.
///
/// The `node_count` is the total expected number of nodes in the [`MastForest`] **after
/// deserialization**.
pub fn try_into_mast_node(
self,
mast_forest: &mut MastForest,
mast_forest: &MastForest,
node_count: usize,
basic_block_data_decoder: &BasicBlockDataDecoder,
) -> Result<MastNode, DeserializationError> {
match self.ty {
Expand All @@ -59,29 +64,29 @@ impl MastNodeInfo {
Ok(MastNode::Block(block))
},
MastNodeType::Join { left_child_id, right_child_id } => {
let left_child = MastNodeId::from_u32_safe(left_child_id, mast_forest)?;
let right_child = MastNodeId::from_u32_safe(right_child_id, mast_forest)?;
let left_child = MastNodeId::from_u32_with_node_count(left_child_id, node_count)?;
let right_child = MastNodeId::from_u32_with_node_count(right_child_id, node_count)?;
let join = JoinNode::new_unsafe([left_child, right_child], self.digest);
Ok(MastNode::Join(join))
},
MastNodeType::Split { if_branch_id, else_branch_id } => {
let if_branch = MastNodeId::from_u32_safe(if_branch_id, mast_forest)?;
let else_branch = MastNodeId::from_u32_safe(else_branch_id, mast_forest)?;
let if_branch = MastNodeId::from_u32_with_node_count(if_branch_id, node_count)?;
let else_branch = MastNodeId::from_u32_with_node_count(else_branch_id, node_count)?;
let split = SplitNode::new_unsafe([if_branch, else_branch], self.digest);
Ok(MastNode::Split(split))
},
MastNodeType::Loop { body_id } => {
let body_id = MastNodeId::from_u32_safe(body_id, mast_forest)?;
let body_id = MastNodeId::from_u32_with_node_count(body_id, node_count)?;
let loop_node = LoopNode::new_unsafe(body_id, self.digest);
Ok(MastNode::Loop(loop_node))
},
MastNodeType::Call { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let call = CallNode::new_unsafe(callee_id, self.digest);
Ok(MastNode::Call(call))
},
MastNodeType::SysCall { callee_id } => {
let callee_id = MastNodeId::from_u32_safe(callee_id, mast_forest)?;
let callee_id = MastNodeId::from_u32_with_node_count(callee_id, node_count)?;
let syscall = CallNode::new_syscall_unsafe(callee_id, self.digest);
Ok(MastNode::Call(syscall))
},
Expand Down
7 changes: 5 additions & 2 deletions core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ impl Deserializable for MastForest {
for _ in 0..node_count {
let mast_node_info = MastNodeInfo::read_from(source)?;

let node = mast_node_info
.try_into_mast_node(&mut mast_forest, &basic_block_data_decoder)?;
let node = mast_node_info.try_into_mast_node(
&mast_forest,
node_count,
&basic_block_data_decoder,
)?;

mast_forest.add_node(node).map_err(|e| {
DeserializationError::InvalidValue(format!(
Expand Down
45 changes: 45 additions & 0 deletions core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,51 @@ fn serialize_deserialize_all_nodes() {
assert_eq!(mast_forest, deserialized_mast_forest);
}

/// Test that a forest with a node whose child ids are larger than its own id serializes and
/// deserializes successfully.
#[test]
fn mast_forest_serialize_deserialize_with_child_ids_exceeding_parent_id() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let zero = forest.add_block(vec![Operation::U32div], None).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

// Move the Join node before its child nodes and remove the temporary zero node.
forest.nodes.swap_remove(zero.as_usize());

MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
}

/// Test that a forest with a node whose referenced index is >= the max number of nodes in
/// the forest returns an error during deserialization.
#[test]
fn mast_forest_serialize_deserialize_with_overflowing_ids_fails() {
let mut overflow_forest = MastForest::new();
let id0 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id2 = overflow_forest.add_block(vec![Operation::Eqz], None).unwrap();
let id_join = overflow_forest.add_join(id0, id2).unwrap();

let join_node = overflow_forest[id_join].clone();

// Add the Join(0, 2) to this forest which does not have a node with index 2.
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
forest
.add_block(vec![Operation::U32add], Some(vec![(0, deco0), (1, deco1)]))
.unwrap();
forest.add_node(join_node).unwrap();

assert_matches!(
MastForest::read_from_bytes(&forest.to_bytes()),
Err(DeserializationError::InvalidValue(msg)) if msg.contains("number of nodes")
);
}

#[test]
fn mast_forest_invalid_node_id() {
// Hydrate a forest smaller than the second
Expand Down

0 comments on commit f1c0553

Please sign in to comment.