diff --git a/acvm-repo/acvm/src/compiler/simulator.rs b/acvm-repo/acvm/src/compiler/simulator.rs index 96134926f5e..b21b8374624 100644 --- a/acvm-repo/acvm/src/compiler/simulator.rs +++ b/acvm-repo/acvm/src/compiler/simulator.rs @@ -7,13 +7,7 @@ use acir::{ }, native_types::{Expression, Witness}, }; -use std::collections::{BTreeSet, HashMap, HashSet}; - -#[derive(PartialEq)] -enum BlockStatus { - Initialized, - Used, -} +use std::collections::{BTreeSet, HashSet}; /// Simulate a symbolic solve for a circuit #[derive(Default)] @@ -21,11 +15,8 @@ pub struct CircuitSimulator { /// Track the witnesses that can be solved solvable_witness: HashSet, - /// Tells whether a Memory Block is: - /// - Not initialized if not in the map - /// - Initialized if its status is Initialized in the Map - /// - Used, indicating that the block cannot be written anymore. - resolved_blocks: HashMap, + /// Track whether a [`BlockId`] has been initialized + initialized_blocks: HashSet, } impl CircuitSimulator { @@ -84,6 +75,11 @@ impl CircuitSimulator { true } Opcode::MemoryOp { block_id, op, predicate } => { + if !self.initialized_blocks.contains(block_id) { + // Memory must be initialized before it can be used. + return false; + } + if !self.can_solve_expression(&op.index) { return false; } @@ -93,14 +89,12 @@ impl CircuitSimulator { } } if op.operation.is_zero() { - let w = op.value.to_witness().unwrap(); + let Some(w) = op.value.to_witness() else { + return false; + }; self.mark_solvable(w); true } else { - if let Some(BlockStatus::Used) = self.resolved_blocks.get(block_id) { - // Writing after having used the block should not be allowed - return false; - } self.try_solve(&Opcode::AssertZero(op.value.clone())) } } @@ -110,7 +104,7 @@ impl CircuitSimulator { return false; } } - self.resolved_blocks.insert(*block_id, BlockStatus::Initialized); + self.initialized_blocks.insert(*block_id); true } Opcode::BrilligCall { id: _, inputs, outputs, predicate } => { @@ -186,18 +180,7 @@ impl CircuitSimulator { true } - BrilligInputs::MemoryArray(block_id) => match self.resolved_blocks.entry(*block_id) { - std::collections::hash_map::Entry::Vacant(_) => false, - std::collections::hash_map::Entry::Occupied(entry) - if *entry.get() == BlockStatus::Used => - { - true - } - std::collections::hash_map::Entry::Occupied(mut entry) => { - entry.insert(BlockStatus::Used); - true - } - }, + BrilligInputs::MemoryArray(block_id) => self.initialized_blocks.contains(block_id), } } @@ -217,7 +200,11 @@ mod tests { use acir::{ FieldElement, acir_field::AcirField, - circuit::{Circuit, ExpressionWidth, Opcode, PublicInputs}, + circuit::{ + Circuit, ExpressionWidth, Opcode, PublicInputs, + brillig::{BrilligFunctionId, BrilligInputs}, + opcodes::{BlockId, BlockType, MemOp}, + }, native_types::{Expression, Witness}, }; @@ -289,4 +276,44 @@ mod tests { assert!(!CircuitSimulator::default().check_circuit(&disconnected_circuit)); } + + #[test] + fn reports_true_when_memory_block_passed_to_brillig_and_then_written_to() { + let circuit = test_circuit( + vec![ + Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![(FieldElement::one(), Witness(1))], + q_c: FieldElement::zero(), + }), + Opcode::MemoryInit { + block_id: BlockId(0), + init: vec![Witness(0)], + block_type: BlockType::Memory, + }, + Opcode::BrilligCall { + id: BrilligFunctionId(0), + inputs: vec![BrilligInputs::MemoryArray(BlockId(0))], + outputs: Vec::new(), + predicate: None, + }, + Opcode::MemoryOp { + block_id: BlockId(0), + op: MemOp::read_at_mem_index( + Expression { + mul_terms: Vec::new(), + linear_combinations: Vec::new(), + q_c: FieldElement::one(), + }, + Witness(2), + ), + predicate: None, + }, + ], + BTreeSet::from([Witness(1)]), + PublicInputs::default(), + ); + + assert!(!CircuitSimulator::default().check_circuit(&circuit)); + } }