Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 58 additions & 31 deletions acvm-repo/acvm/src/compiler/simulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,16 @@ use acir::{
},
native_types::{Expression, Witness},
};
use std::collections::{BTreeSet, HashMap, HashSet};

#[derive(PartialEq)]
enum BlockStatus {
Initialized,
Used,
}
use std::collections::{BTreeSet, HashSet};

/// Simulate a symbolic solve for a circuit
#[derive(Default)]
pub struct CircuitSimulator {
/// Track the witnesses that can be solved
solvable_witness: HashSet<Witness>,

/// Tells whether a Memory Block is:
/// - Not initialized if not in the map
/// - Initialized if its status is Initialized in the Map
/// - Used, indicating that the block cannot be written anymore.
resolved_blocks: HashMap<BlockId, BlockStatus>,
/// Track whether a [`BlockId`] has been initialized
initialized_blocks: HashSet<BlockId>,
}

impl CircuitSimulator {
Expand Down Expand Up @@ -84,6 +75,11 @@ impl CircuitSimulator {
true
}
Opcode::MemoryOp { block_id, op, predicate } => {
if !self.initialized_blocks.contains(block_id) {
// Memory must be initialized before it can be used.
return false;
}

if !self.can_solve_expression(&op.index) {
return false;
}
Expand All @@ -93,14 +89,12 @@ impl CircuitSimulator {
}
}
if op.operation.is_zero() {
let w = op.value.to_witness().unwrap();
let Some(w) = op.value.to_witness() else {
return false;
};
self.mark_solvable(w);
true
} else {
if let Some(BlockStatus::Used) = self.resolved_blocks.get(block_id) {
// Writing after having used the block should not be allowed
return false;
}
self.try_solve(&Opcode::AssertZero(op.value.clone()))
}
}
Expand All @@ -110,7 +104,7 @@ impl CircuitSimulator {
return false;
}
}
self.resolved_blocks.insert(*block_id, BlockStatus::Initialized);
self.initialized_blocks.insert(*block_id);
true
}
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
Expand Down Expand Up @@ -186,18 +180,7 @@ impl CircuitSimulator {
true
}

BrilligInputs::MemoryArray(block_id) => match self.resolved_blocks.entry(*block_id) {
std::collections::hash_map::Entry::Vacant(_) => false,
std::collections::hash_map::Entry::Occupied(entry)
if *entry.get() == BlockStatus::Used =>
{
true
}
std::collections::hash_map::Entry::Occupied(mut entry) => {
entry.insert(BlockStatus::Used);
true
}
},
BrilligInputs::MemoryArray(block_id) => self.initialized_blocks.contains(block_id),
}
}

Expand All @@ -217,7 +200,11 @@ mod tests {
use acir::{
FieldElement,
acir_field::AcirField,
circuit::{Circuit, ExpressionWidth, Opcode, PublicInputs},
circuit::{
Circuit, ExpressionWidth, Opcode, PublicInputs,
brillig::{BrilligFunctionId, BrilligInputs},
opcodes::{BlockId, BlockType, MemOp},
},
native_types::{Expression, Witness},
};

Expand Down Expand Up @@ -289,4 +276,44 @@ mod tests {

assert!(!CircuitSimulator::default().check_circuit(&disconnected_circuit));
}

#[test]
fn reports_true_when_memory_block_passed_to_brillig_and_then_written_to() {
let circuit = test_circuit(
vec![
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![(FieldElement::one(), Witness(1))],
q_c: FieldElement::zero(),
}),
Opcode::MemoryInit {
block_id: BlockId(0),
init: vec![Witness(0)],
block_type: BlockType::Memory,
},
Opcode::BrilligCall {
id: BrilligFunctionId(0),
inputs: vec![BrilligInputs::MemoryArray(BlockId(0))],
outputs: Vec::new(),
predicate: None,
},
Opcode::MemoryOp {
block_id: BlockId(0),
op: MemOp::read_at_mem_index(
Expression {
mul_terms: Vec::new(),
linear_combinations: Vec::new(),
q_c: FieldElement::one(),
},
Witness(2),
),
predicate: None,
},
],
BTreeSet::from([Witness(1)]),
PublicInputs::default(),
);

assert!(!CircuitSimulator::default().check_circuit(&circuit));
}
}
Loading