diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 1a183e95c87..158f9a7451d 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -126,8 +126,10 @@ impl MergeExpressionsOptimizer { self.modified_gates.insert(target, Opcode::AssertZero(expr)); self.deleted_gates.insert(source); // Update the 'used_witness' map to account for the merge. - let mut witness_list = CircuitSimulator::expr_wit(&expr_use); - witness_list.extend(CircuitSimulator::expr_wit(&expr_define)); + let witness_list = CircuitSimulator::expr_wit(&expr_use); + let witness_list = witness_list + .chain(CircuitSimulator::expr_wit(&expr_define)); + for w2 in witness_list { if !circuit_io.contains(&w2) { used_witness.entry(w2).and_modify(|v| { @@ -165,42 +167,43 @@ impl MergeExpressionsOptimizer { (new_circuit, new_acir_opcode_positions) } - fn brillig_input_wit(&self, input: &BrilligInputs) -> BTreeSet { - let mut result = BTreeSet::new(); + fn for_each_brillig_input_wit(&self, input: &BrilligInputs, mut f: impl FnMut(Witness)) { match input { BrilligInputs::Single(expr) => { - result.extend(CircuitSimulator::expr_wit(expr)); + for witness in CircuitSimulator::expr_wit(expr) { + f(witness); + } } BrilligInputs::Array(exprs) => { for expr in exprs { - result.extend(CircuitSimulator::expr_wit(expr)); + for witness in CircuitSimulator::expr_wit(expr) { + f(witness); + } } } BrilligInputs::MemoryArray(block_id) => { - let witnesses = self.resolved_blocks.get(block_id).expect("Unknown block id"); - result.extend(witnesses); + for witness in self.resolved_blocks.get(block_id).expect("Unknown block id") { + f(*witness); + } } } - result } - fn brillig_output_wit(&self, output: &BrilligOutputs) -> BTreeSet { - let mut result = BTreeSet::new(); + fn for_each_brillig_output_wit(&self, output: &BrilligOutputs, mut f: impl FnMut(Witness)) { match output { - BrilligOutputs::Simple(witness) => { - result.insert(*witness); - } + BrilligOutputs::Simple(witness) => f(*witness), BrilligOutputs::Array(witnesses) => { - result.extend(witnesses); + for witness in witnesses { + f(*witness); + } } } - result } // Returns the input witnesses used by the opcode fn witness_inputs(&self, opcode: &Opcode) -> BTreeSet { match opcode { - Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr), + Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr).collect(), Opcode::BlackBoxFuncCall(bb_func) => { let mut witnesses = bb_func.get_input_witnesses(); witnesses.extend(bb_func.get_outputs_vec()); @@ -209,9 +212,8 @@ impl MergeExpressionsOptimizer { } Opcode::MemoryOp { block_id: _, op } => { //index and value - let mut witnesses = CircuitSimulator::expr_wit(&op.index); - witnesses.extend(CircuitSimulator::expr_wit(&op.value)); - witnesses + let witnesses = CircuitSimulator::expr_wit(&op.index); + witnesses.chain(CircuitSimulator::expr_wit(&op.value)).collect() } Opcode::MemoryInit { block_id: _, init, block_type: _ } => { @@ -220,15 +222,19 @@ impl MergeExpressionsOptimizer { Opcode::BrilligCall { inputs, outputs, .. } => { let mut witnesses = BTreeSet::new(); for i in inputs { - witnesses.extend(self.brillig_input_wit(i)); + self.for_each_brillig_input_wit(i, |witness| { + witnesses.insert(witness); + }); } for i in outputs { - witnesses.extend(self.brillig_output_wit(i)); + self.for_each_brillig_output_wit(i, |witness| { + witnesses.insert(witness); + }); } witnesses } Opcode::Call { id: _, inputs, outputs, predicate } => { - let mut witnesses: BTreeSet = BTreeSet::from_iter(inputs.iter().copied()); + let mut witnesses: BTreeSet = inputs.iter().copied().collect(); witnesses.extend(outputs); if let Some(p) = predicate { diff --git a/acvm-repo/acvm/src/compiler/simulator.rs b/acvm-repo/acvm/src/compiler/simulator.rs index 47e62c2409c..2917080a93d 100644 --- a/acvm-repo/acvm/src/compiler/simulator.rs +++ b/acvm-repo/acvm/src/compiler/simulator.rs @@ -7,7 +7,7 @@ use acir::{ }, native_types::{Expression, Witness}, }; -use std::collections::{BTreeSet, HashSet}; +use std::collections::HashSet; /// Simulate a symbolic solve for a circuit /// Instead of evaluating witness values from the inputs, like the PWG module is doing, @@ -183,11 +183,11 @@ impl CircuitSimulator { } } - pub(crate) fn expr_wit(expr: &Expression) -> BTreeSet { - let mut result = BTreeSet::new(); - result.extend(expr.mul_terms.iter().flat_map(|i| [i.1, i.2])); - result.extend(expr.linear_combinations.iter().map(|i| i.1)); - result + pub(crate) fn expr_wit(expr: &Expression) -> impl Iterator { + expr.mul_terms + .iter() + .flat_map(|i| [i.1, i.2]) + .chain(expr.linear_combinations.iter().map(|i| i.1)) } }