Skip to content

Commit

Permalink
feature: add MastForest::advice_map and load it before the execution
Browse files Browse the repository at this point in the history
When merging forests, merge their advice maps and return error on key
collision.
  • Loading branch information
greenhat committed Nov 19, 2024
1 parent 65a3060 commit 006e03a
Show file tree
Hide file tree
Showing 18 changed files with 222 additions and 24 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 0.12.0 (TBD)

#### Enhancements
- Added `miden_core::mast::MastForest::advice_map` to load it into the advice provider before the `MastForest` execution (#1574).

## 0.11.0 (2024-11-04)

#### Enhancements
Expand Down
10 changes: 10 additions & 0 deletions core/src/advice/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ impl AdviceMap {
pub fn remove(&mut self, key: RpoDigest) -> Option<Vec<Felt>> {
self.0.remove(&key)
}

/// Returns the number of key value pairs in the advice map.
pub fn len(&self) -> usize {
self.0.len()
}

/// Returns true if the advice map is empty.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
}

impl From<BTreeMap<RpoDigest, Vec<Felt>>> for AdviceMap {
Expand Down
21 changes: 18 additions & 3 deletions core/src/mast/merger/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ impl MastForestMerger {
///
/// It does this in three steps:
///
/// 1. Merge all decorators, which is a case of deduplication and creating a decorator id
/// 1. Merge all advice maps, checking for key collisions.
/// 2. Merge all decorators, which is a case of deduplication and creating a decorator id
/// mapping which contains how existing [`DecoratorId`]s map to [`DecoratorId`]s in the
/// merged forest.
/// 2. Merge all nodes of forests.
/// 3. Merge all nodes of forests.
/// - Similar to decorators, node indices might move during merging, so the merger keeps a
/// node id mapping as it merges nodes.
/// - This is a depth-first traversal over all forests to ensure all children are processed
Expand All @@ -90,10 +91,13 @@ impl MastForestMerger {
/// `replacement` node. Now we can simply add a mapping from the external node to the
/// `replacement` node in our node id mapping which means all nodes that referenced the
/// external node will point to the `replacement` instead.
/// 3. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// 4. Finally, we merge all roots of all forests. Here we map the existing root indices to
/// their potentially new indices in the merged forest and add them to the forest,
/// deduplicating in the process, too.
fn merge_inner(&mut self, forests: Vec<&MastForest>) -> Result<(), MastForestError> {
for other_forest in forests.iter() {
self.merge_advice_map(other_forest)?;
}
for other_forest in forests.iter() {
self.merge_decorators(other_forest)?;
}
Expand Down Expand Up @@ -163,6 +167,17 @@ impl MastForestMerger {
Ok(())
}

fn merge_advice_map(&mut self, other_forest: &MastForest) -> Result<(), MastForestError> {
for (key, value) in other_forest.advice_map.clone().into_iter() {
if self.mast_forest.advice_map().get(&key).is_some() {
return Err(MastForestError::AdviceMapKeyCollisionOnMerge(key));
} else {
self.mast_forest.advice_map_mut().insert(key, value.clone());
}
}
Ok(())
}

fn merge_node(
&mut self,
forest_idx: usize,
Expand Down
53 changes: 52 additions & 1 deletion core/src/mast/merger/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use miden_crypto::{hash::rpo::RpoDigest, ONE};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{Decorator, Operation};
Expand Down Expand Up @@ -794,3 +794,54 @@ fn mast_forest_merge_invalid_decorator_index() {
let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::DecoratorIdOverflow(_, _));
}

/// Tests that forest's advice maps are merged correctly.
#[test]
fn mast_forest_merge_advice_maps_merged() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
let key_b = RpoDigest::new([Felt::new(1), Felt::new(3), Felt::new(2), Felt::new(1)]);
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let (merged, _root_maps) = MastForest::merge([&forest_a, &forest_b]).unwrap();

let merged_advice_map = merged.advice_map();
assert_eq!(merged_advice_map.len(), 2);
assert_eq!(merged_advice_map.get(&key_a).unwrap(), &value_a);
assert_eq!(merged_advice_map.get(&key_b).unwrap(), &value_b);
}

/// Tests that an error is returned when advice maps have a key collision.
#[test]
fn mast_forest_merge_advice_maps_collision() {
let mut forest_a = MastForest::new();
let id_foo = forest_a.add_node(block_foo()).unwrap();
let id_call_a = forest_a.add_call(id_foo).unwrap();
forest_a.make_root(id_call_a);
let key_a = RpoDigest::new([Felt::new(1), Felt::new(2), Felt::new(3), Felt::new(4)]);
let value_a = vec![ONE, ONE];
forest_a.advice_map_mut().insert(key_a, value_a.clone());

let mut forest_b = MastForest::new();
let id_bar = forest_b.add_node(block_bar()).unwrap();
let id_call_b = forest_b.add_call(id_bar).unwrap();
forest_b.make_root(id_call_b);
// The key collides with key_a in the forest_a.
let key_b = key_a;
let value_b = vec![Felt::new(2), Felt::new(2)];
forest_b.advice_map_mut().insert(key_b, value_b.clone());

let err = MastForest::merge([&forest_a, &forest_b]).unwrap_err();
assert_matches!(err, MastForestError::AdviceMapKeyCollisionOnMerge(_));
}
15 changes: 14 additions & 1 deletion core/src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub use node::{
};
use winter_utils::{ByteWriter, DeserializationError, Serializable};

use crate::{Decorator, DecoratorList, Operation};
use crate::{AdviceMap, Decorator, DecoratorList, Operation};

mod serialization;

Expand Down Expand Up @@ -50,6 +50,9 @@ pub struct MastForest {

/// All the decorators included in the MAST forest.
decorators: Vec<Decorator>,

/// Advice map to be loaded into the VM prior to executing procedures from this MAST forest.
advice_map: AdviceMap,
}

// ------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -463,6 +466,14 @@ impl MastForest {
pub fn nodes(&self) -> &[MastNode] {
&self.nodes
}

pub fn advice_map(&self) -> &AdviceMap {
&self.advice_map
}

pub fn advice_map_mut(&mut self) -> &mut AdviceMap {
&mut self.advice_map
}
}

impl Index<MastNodeId> for MastForest {
Expand Down Expand Up @@ -689,4 +700,6 @@ pub enum MastForestError {
EmptyBasicBlock,
#[error("decorator root of child with node id {0} is missing but required for fingerprint computation")]
ChildFingerprintMissing(MastNodeId),
#[error("advice map key already exists when merging forests: {0}")]
AdviceMapKeyCollisionOnMerge(RpoDigest),
}
5 changes: 4 additions & 1 deletion core/src/mast/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use string_table::{StringTable, StringTableBuilder};
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

use super::{DecoratorId, MastForest, MastNode, MastNodeId};
use crate::AdviceMap;

mod decorator;

Expand Down Expand Up @@ -71,7 +72,7 @@ const MAGIC: &[u8; 5] = b"MAST\0";
/// If future modifications are made to this format, the version should be incremented by 1. A
/// version of `[255, 255, 255]` is reserved for future extensions that require extending the
/// version field itself, but should be considered invalid for now.
const VERSION: [u8; 3] = [0, 0, 0];
const VERSION: [u8; 3] = [0, 0, 1];

// MAST FOREST SERIALIZATION/DESERIALIZATION
// ================================================================================================
Expand Down Expand Up @@ -161,6 +162,7 @@ impl Serializable for MastForest {
// Write "before enter" and "after exit" decorators
before_enter_decorators.write_into(target);
after_exit_decorators.write_into(target);
self.advice_map.write_into(target);
}
}

Expand Down Expand Up @@ -256,6 +258,7 @@ impl Deserializable for MastForest {
let node_id = MastNodeId::from_u32_safe(node_id, &mast_forest)?;
mast_forest.set_after_exit(node_id, decorator_ids);
}
mast_forest.advice_map = AdviceMap::read_from(source)?;

Ok(mast_forest)
}
Expand Down
21 changes: 20 additions & 1 deletion core/src/mast/serialization/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::{string::ToString, sync::Arc};

use miden_crypto::{hash::rpo::RpoDigest, Felt};
use miden_crypto::{hash::rpo::RpoDigest, Felt, ONE};

use super::*;
use crate::{
Expand Down Expand Up @@ -435,3 +435,22 @@ fn mast_forest_invalid_node_id() {
// Validate normal operations
forest.add_join(first, second).unwrap();
}

/// Test `MastForest::advice_map` serialization and deserialization.
#[test]
fn mast_forest_serialize_deserialize_advice_map() {
let mut forest = MastForest::new();
let deco0 = forest.add_decorator(Decorator::Trace(0)).unwrap();
let deco1 = forest.add_decorator(Decorator::Trace(1)).unwrap();
let first = forest.add_block(vec![Operation::U32add], Some(vec![(0, deco0)])).unwrap();
let second = forest.add_block(vec![Operation::U32and], Some(vec![(1, deco1)])).unwrap();
forest.add_join(first, second).unwrap();

let key = RpoDigest::new([ONE, ONE, ONE, ONE]);
let value = vec![ONE, ONE];

forest.advice_map_mut().insert(key, value);

let parsed = MastForest::read_from_bytes(&forest.to_bytes()).unwrap();
assert_eq!(forest.advice_map, parsed.advice_map);
}
2 changes: 1 addition & 1 deletion miden/benches/program_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn program_execution(c: &mut Criterion) {

let stdlib = StdLibrary::default();
let mut host = DefaultHost::default();
host.load_mast_forest(stdlib.as_ref().mast_forest().clone());
host.load_mast_forest(stdlib.as_ref().mast_forest().clone()).unwrap();

group.bench_function("sha256", |bench| {
let source = "
Expand Down
2 changes: 1 addition & 1 deletion miden/src/examples/blake3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub fn get_example(n: usize) -> Example<DefaultHost<MemAdviceProvider>> {
);

let mut host = DefaultHost::default();
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone()).unwrap();

let stack_inputs =
StackInputs::try_from_ints(INITIAL_HASH_VALUE.iter().map(|&v| v as u64)).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion miden/src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ fn execute(
let stack_inputs = StackInputs::default();
let mut host = DefaultHost::default();
for library in provided_libraries {
host.load_mast_forest(library.mast_forest().clone());
host.load_mast_forest(library.mast_forest().clone())
.map_err(|err| format!("{err}"))?;
}

let state_iter = processor::execute_iter(&program, stack_inputs, host);
Expand Down
3 changes: 2 additions & 1 deletion miden/src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ impl Analyze {
// fetch the stack and program inputs from the arguments
let stack_inputs = input_data.parse_stack_inputs().map_err(Report::msg)?;
let mut host = DefaultHost::new(input_data.parse_advice_provider().map_err(Report::msg)?);
host.load_mast_forest(StdLibrary::default().mast_forest().clone());
host.load_mast_forest(StdLibrary::default().mast_forest().clone())
.into_diagnostic()?;

let execution_details: ExecutionDetails = analyze(program.as_str(), stack_inputs, host)
.expect("Could not retrieve execution details");
Expand Down
53 changes: 53 additions & 0 deletions miden/tests/integration/exec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use assembly::Assembler;
use miden_vm::DefaultHost;
use processor::{ExecutionOptions, MastForest};
use prover::{Digest, StackInputs};
use vm_core::{assert_matches, Program, ONE};

#[test]
fn advice_map_loaded_before_execution() {
let source = "\
begin
push.1.1.1.1
adv.push_mapval
dropw
end";

// compile and execute program
let program_without_advice_map: Program =
Assembler::default().assemble_program(source).unwrap();

// Test `processor::execute` fails if no advice map provided with the program
let mut host = DefaultHost::default();
match processor::execute(
&program_without_advice_map,
StackInputs::default(),
&mut host,
ExecutionOptions::default(),
) {
Ok(_) => panic!("Expected error"),
Err(e) => {
assert_matches!(e, prover::ExecutionError::AdviceMapKeyNotFound(_));
},
}

// Test `processor::execute` works if advice map provided with the program
let mast_forest: MastForest = (**program_without_advice_map.mast_forest()).clone();

let key = Digest::new([ONE, ONE, ONE, ONE]);
let value = vec![ONE, ONE];

let mut mast_forest = mast_forest.clone();
mast_forest.advice_map_mut().insert(key, value);
let program_with_advice_map =
Program::new(mast_forest.into(), program_without_advice_map.entrypoint());

let mut host = DefaultHost::default();
processor::execute(
&program_with_advice_map,
StackInputs::default(),
&mut host,
ExecutionOptions::default(),
)
.unwrap();
}
1 change: 1 addition & 0 deletions miden/tests/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use test_utils::{build_op_test, build_test};

mod air;
mod cli;
mod exec;
mod exec_iters;
mod flow_control;
mod operations;
Expand Down
5 changes: 5 additions & 0 deletions processor/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use super::{
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionError {
AdviceMapKeyNotFound(Word),
AdviceMapKeyAlreadyPresent(Word),
AdviceStackReadFailed(RowIndex),
CallerNotInSyscall,
CircularExternalNode(Digest),
Expand Down Expand Up @@ -96,6 +97,10 @@ impl Display for ExecutionError {
let hex = to_hex(Felt::elements_as_bytes(key));
write!(f, "Value for key {hex} not present in the advice map")
},
AdviceMapKeyAlreadyPresent(key) => {
let hex = to_hex(Felt::elements_as_bytes(key));
write!(f, "Value for key {hex} already present in the advice map")
},
AdviceStackReadFailed(step) => write!(f, "Advice stack read failed at step {step}"),
CallerNotInSyscall => {
write!(f, "Instruction `caller` used outside of kernel context")
Expand Down
13 changes: 11 additions & 2 deletions processor/src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,17 @@ where
}
}

pub fn load_mast_forest(&mut self, mast_forest: Arc<MastForest>) {
self.store.insert(mast_forest)
pub fn load_mast_forest(&mut self, mast_forest: Arc<MastForest>) -> Result<(), ExecutionError> {
// Load the MAST's advice data into the advice provider.
for (digest, values) in mast_forest.advice_map().clone().into_iter() {
if self.adv_provider.get_mapped_values(&digest).is_some() {
return Err(ExecutionError::AdviceMapKeyAlreadyPresent(digest.into()));
} else {
self.adv_provider.insert_into_map(digest.into(), values);
}
}
self.store.insert(mast_forest);
Ok(())
}

#[cfg(any(test, feature = "testing"))]
Expand Down
12 changes: 12 additions & 0 deletions processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ where
return Err(ExecutionError::ProgramAlreadyExecuted);
}

// Load the program's advice data into the advice provider
for (digest, values) in program.mast_forest().advice_map().clone().into_iter() {
if self.host.borrow().advice_provider().get_mapped_values(&digest).is_some() {
return Err(ExecutionError::AdviceMapKeyAlreadyPresent(digest.into()));
} else {
self.host
.borrow_mut()
.advice_provider_mut()
.insert_into_map(digest.into(), values);
}
}

self.execute_mast_node(program.entrypoint(), &program.mast_forest().clone())?;

self.stack.build_stack_outputs()
Expand Down
Loading

0 comments on commit 006e03a

Please sign in to comment.