Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
self.modified_gates.insert(target, Opcode::AssertZero(expr));
self.deleted_gates.insert(source);
// Update the 'used_witness' map to account for the merge.
let mut witness_list = CircuitSimulator::expr_wit(&expr_use);
witness_list.extend(CircuitSimulator::expr_wit(&expr_define));
let witness_list = CircuitSimulator::expr_wit(&expr_use);
let witness_list = witness_list
.chain(CircuitSimulator::expr_wit(&expr_define));

for w2 in witness_list {
if !circuit_io.contains(&w2) {
used_witness.entry(w2).and_modify(|v| {
Expand Down Expand Up @@ -165,42 +167,43 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
(new_circuit, new_acir_opcode_positions)
}

fn brillig_input_wit(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
fn for_each_brillig_input_wit(&self, input: &BrilligInputs<F>, mut f: impl FnMut(Witness)) {
match input {
BrilligInputs::Single(expr) => {
result.extend(CircuitSimulator::expr_wit(expr));
for witness in CircuitSimulator::expr_wit(expr) {
f(witness);
}
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
result.extend(CircuitSimulator::expr_wit(expr));
for witness in CircuitSimulator::expr_wit(expr) {
f(witness);
}
}
}
BrilligInputs::MemoryArray(block_id) => {
let witnesses = self.resolved_blocks.get(block_id).expect("Unknown block id");
result.extend(witnesses);
for witness in self.resolved_blocks.get(block_id).expect("Unknown block id") {
f(*witness);
}
}
}
result
}

fn brillig_output_wit(&self, output: &BrilligOutputs) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
fn for_each_brillig_output_wit(&self, output: &BrilligOutputs, mut f: impl FnMut(Witness)) {
match output {
BrilligOutputs::Simple(witness) => {
result.insert(*witness);
}
BrilligOutputs::Simple(witness) => f(*witness),
BrilligOutputs::Array(witnesses) => {
result.extend(witnesses);
for witness in witnesses {
f(*witness);
}
}
}
result
}

// Returns the input witnesses used by the opcode
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr).collect(),
Opcode::BlackBoxFuncCall(bb_func) => {
let mut witnesses = bb_func.get_input_witnesses();
witnesses.extend(bb_func.get_outputs_vec());
Expand All @@ -209,9 +212,8 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
}
Opcode::MemoryOp { block_id: _, op } => {
//index and value
let mut witnesses = CircuitSimulator::expr_wit(&op.index);
witnesses.extend(CircuitSimulator::expr_wit(&op.value));
witnesses
let witnesses = CircuitSimulator::expr_wit(&op.index);
witnesses.chain(CircuitSimulator::expr_wit(&op.value)).collect()
}

Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
Expand All @@ -220,15 +222,19 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
Opcode::BrilligCall { inputs, outputs, .. } => {
let mut witnesses = BTreeSet::new();
for i in inputs {
witnesses.extend(self.brillig_input_wit(i));
self.for_each_brillig_input_wit(i, |witness| {
witnesses.insert(witness);
});
}
for i in outputs {
witnesses.extend(self.brillig_output_wit(i));
self.for_each_brillig_output_wit(i, |witness| {
witnesses.insert(witness);
});
}
witnesses
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
let mut witnesses: BTreeSet<Witness> = BTreeSet::from_iter(inputs.iter().copied());
let mut witnesses: BTreeSet<Witness> = inputs.iter().copied().collect();
witnesses.extend(outputs);

if let Some(p) = predicate {
Expand Down
12 changes: 6 additions & 6 deletions acvm-repo/acvm/src/compiler/simulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use acir::{
},
native_types::{Expression, Witness},
};
use std::collections::{BTreeSet, HashSet};
use std::collections::HashSet;

/// Simulate a symbolic solve for a circuit
/// Instead of evaluating witness values from the inputs, like the PWG module is doing,
Expand Down Expand Up @@ -183,11 +183,11 @@ impl CircuitSimulator {
}
}

pub(crate) fn expr_wit<F>(expr: &Expression<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| [i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
pub(crate) fn expr_wit<F>(expr: &Expression<F>) -> impl Iterator<Item = Witness> {
expr.mul_terms
.iter()
.flat_map(|i| [i.1, i.2])
.chain(expr.linear_combinations.iter().map(|i| i.1))
}
}

Expand Down
Loading