diff --git a/acvm-repo/acvm/src/compiler/optimizers/general.rs b/acvm-repo/acvm/src/compiler/optimizers/general.rs index 0802f33185c..baeee12a14b 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/general.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/general.rs @@ -12,22 +12,13 @@ pub(crate) struct GeneralOptimizer; impl GeneralOptimizer { pub(crate) fn optimize(opcode: Expression) -> Expression { // XXX: Perhaps this optimization can be done on the fly - let opcode = remove_zero_coefficients(opcode); let opcode = simplify_mul_terms(opcode); simplify_linear_terms(opcode) } } -// Remove all terms with zero as a coefficient -fn remove_zero_coefficients(mut opcode: Expression) -> Expression { - // Check the mul terms - opcode.mul_terms.retain(|(scale, _, _)| !scale.is_zero()); - // Check the linear combination terms - opcode.linear_combinations.retain(|(scale, _)| !scale.is_zero()); - opcode -} - -// Simplifies all mul terms with the same bi-variate variables +// Simplifies all mul terms with the same bi-variate variables while also removing +// terms that end up with a zero coefficient. fn simplify_mul_terms(mut gate: Expression) -> Expression { let mut hash_map: IndexMap<(Witness, Witness), F> = IndexMap::new(); @@ -40,11 +31,16 @@ fn simplify_mul_terms(mut gate: Expression) -> Expression { *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale; } - gate.mul_terms = hash_map.into_iter().map(|((w_l, w_r), scale)| (scale, w_l, w_r)).collect(); + gate.mul_terms = hash_map + .into_iter() + .filter(|(_, scale)| !scale.is_zero()) + .map(|((w_l, w_r), scale)| (scale, w_l, w_r)) + .collect(); gate } -// Simplifies all linear terms with the same variables +// Simplifies all linear terms with the same variables while also removing +// terms that end up with a zero coefficient. fn simplify_linear_terms(mut gate: Expression) -> Expression { let mut hash_map: IndexMap = IndexMap::new(); @@ -60,3 +56,161 @@ fn simplify_linear_terms(mut gate: Expression) -> Expression .collect(); gate } + +#[cfg(test)] +mod tests { + use acir::{ + FieldElement, + circuit::{Circuit, Opcode}, + }; + + use crate::{assert_circuit_snapshot, compiler::optimizers::GeneralOptimizer}; + + fn optimize(circuit: Circuit) -> Circuit { + let opcodes = circuit + .clone() + .opcodes + .into_iter() + .map(|opcode| { + if let Opcode::AssertZero(arith_expr) = opcode { + Opcode::AssertZero(GeneralOptimizer::optimize(arith_expr)) + } else { + opcode + } + }) + .collect(); + let mut optimized_circuit = circuit; + optimized_circuit.opcodes = opcodes; + optimized_circuit + } + + #[test] + fn removes_zero_coefficients_from_mul_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + + // The first multiplication should be removed + EXPR [ (0, _0, _1) (1, _0, _1) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (1, _0, _1) 0 ] + "); + } + + #[test] + fn removes_zero_coefficients_from_linear_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + + // The first linear combination should be removed + EXPR [ (0, _0) (1, _1) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (1, _1) 0 ] + "); + } + + #[test] + fn simplifies_mul_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + + // There are all mul terms with the same variables so we should end up with just one + // that is the sum of all the coefficients + EXPR [ (2, _0, _1) (3, _1, _0) (4, _0, _1) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (9, _0, _1) 0 ] + "); + } + + #[test] + fn removes_zero_coefficients_after_simplifying_mul_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (2, _0, _1) (3, _1, _0) (-5, _0, _1) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ 0 ] + "); + } + + #[test] + fn simplifies_linear_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + + // These are all linear terms with the same variable so we should end up with just one + // that is the sum of all the coefficients + EXPR [ (1, _0) (2, _0) (3, _0) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (6, _0) 0 ] + "); + } + + #[test] + fn removes_zero_coefficients_after_simplifying_linear_terms() { + let src = " + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (1, _0) (2, _0) (-3, _0) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = optimize(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _1 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ 0 ] + "); + } +} diff --git a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs index 31d1b72e68d..b8a37922a0e 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs @@ -279,20 +279,13 @@ impl MergeExpressionsOptimizer { #[cfg(test)] mod tests { - use crate::compiler::{CircuitSimulator, optimizers::MergeExpressionsOptimizer}; - use acir::{ - FieldElement, - acir_field::AcirField, - circuit::{ - Circuit, ExpressionWidth, Opcode, PublicInputs, - brillig::{BrilligFunctionId, BrilligOutputs}, - opcodes::{BlackBoxFuncCall, FunctionInput}, - }, - native_types::{Expression, Witness}, + use crate::{ + assert_circuit_snapshot, + compiler::{CircuitSimulator, optimizers::MergeExpressionsOptimizer}, }; - use std::collections::BTreeSet; + use acir::{FieldElement, circuit::Circuit}; - fn must_check_circuit(circuit: Circuit) -> Circuit { + fn merge_expressions(circuit: Circuit) -> Circuit { assert!(CircuitSimulator::default().check_circuit(&circuit).is_none()); let mut merge_optimizer = MergeExpressionsOptimizer::new(); let acir_opcode_positions = vec![0; 20]; @@ -300,6 +293,7 @@ mod tests { merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions); let mut optimized_circuit = circuit; optimized_circuit.opcodes = opcodes; + // check that the circuit is still valid after optimization assert!(CircuitSimulator::default().check_circuit(&optimized_circuit).is_none()); optimized_circuit @@ -307,142 +301,60 @@ mod tests { #[test] fn does_not_eliminate_witnesses_returned_from_brillig() { - let opcodes = vec![ - Opcode::BrilligCall { - id: BrilligFunctionId::default(), - inputs: Vec::new(), - outputs: vec![BrilligOutputs::Simple(Witness(1))], - predicate: None, - }, - Opcode::AssertZero(Expression { - mul_terms: Vec::new(), - linear_combinations: vec![ - (FieldElement::from(2_u128), Witness(0)), - (FieldElement::from(3_u128), Witness(1)), - (FieldElement::from(1_u128), Witness(2)), - ], - q_c: FieldElement::one(), - }), - Opcode::AssertZero(Expression { - mul_terms: Vec::new(), - linear_combinations: vec![ - (FieldElement::from(2_u128), Witness(0)), - (FieldElement::from(2_u128), Witness(1)), - (FieldElement::from(1_u128), Witness(5)), - ], - q_c: FieldElement::one(), - }), - ]; - - let mut private_parameters = BTreeSet::new(); - private_parameters.insert(Witness(0)); - - let circuit = Circuit { - current_witness_index: 1, - expression_width: ExpressionWidth::Bounded { width: 4 }, - opcodes, - private_parameters, - public_parameters: PublicInputs::default(), - return_values: PublicInputs::default(), - assert_messages: Default::default(), - }; - must_check_circuit(circuit); + let src = " + current witness index : _1 + private parameters indices : [_0] + public parameters indices : [] + return value indices : [] + BRILLIG CALL func 0: inputs: [], outputs: [_1] + EXPR [ (2, _0) (3, _1) (1, _2) 1 ] + EXPR [ (2, _0) (2, _1) (1, _5) 1 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = merge_expressions(circuit.clone()); + assert_eq!(circuit, optimized_circuit); } #[test] fn does_not_eliminate_witnesses_returned_from_circuit() { - let opcodes = vec![ - Opcode::AssertZero(Expression { - mul_terms: vec![(FieldElement::from(-1i128), Witness(0), Witness(0))], - linear_combinations: vec![(FieldElement::from(1i128), Witness(1))], - q_c: FieldElement::zero(), - }), - Opcode::AssertZero(Expression { - mul_terms: Vec::new(), - linear_combinations: vec![ - (FieldElement::from(-1i128), Witness(1)), - (FieldElement::from(1i128), Witness(2)), - ], - q_c: FieldElement::zero(), - }), - ]; - // Witness(1) could be eliminated because it's only used by 2 opcodes. - - let mut private_parameters = BTreeSet::new(); - private_parameters.insert(Witness(0)); - - let mut return_values = BTreeSet::new(); - return_values.insert(Witness(1)); - return_values.insert(Witness(2)); - - let circuit = Circuit { - current_witness_index: 2, - expression_width: ExpressionWidth::Bounded { width: 4 }, - opcodes, - private_parameters, - public_parameters: PublicInputs::default(), - return_values: PublicInputs(return_values), - assert_messages: Default::default(), - }; - - let mut merge_optimizer = MergeExpressionsOptimizer::new(); - let acir_opcode_positions = vec![0; 20]; - let (opcodes, _) = - merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions); - - assert_eq!(opcodes.len(), 2); + let src = " + current witness index : _2 + private parameters indices : [_0] + public parameters indices : [] + return value indices : [_1, _2] + EXPR [ (-1, _0, _0) (1, _1) 0 ] + EXPR [ (-1, _1) (1, _2) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = merge_expressions(circuit.clone()); + assert_eq!(circuit, optimized_circuit); } #[test] fn does_not_attempt_to_merge_into_previous_opcodes() { - let opcodes = vec![ - Opcode::AssertZero(Expression { - mul_terms: vec![(FieldElement::one(), Witness(0), Witness(0))], - linear_combinations: vec![(-FieldElement::one(), Witness(4))], - q_c: FieldElement::zero(), - }), - Opcode::AssertZero(Expression { - mul_terms: vec![(FieldElement::one(), Witness(0), Witness(1))], - linear_combinations: vec![(FieldElement::one(), Witness(5))], - q_c: FieldElement::zero(), - }), - Opcode::AssertZero(Expression { - mul_terms: Vec::new(), - linear_combinations: vec![ - (-FieldElement::one(), Witness(2)), - (FieldElement::one(), Witness(4)), - (FieldElement::one(), Witness(5)), - ], - q_c: FieldElement::zero(), - }), - Opcode::AssertZero(Expression { - mul_terms: Vec::new(), - linear_combinations: vec![ - (FieldElement::one(), Witness(2)), - (-FieldElement::one(), Witness(3)), - (FieldElement::one(), Witness(4)), - (FieldElement::one(), Witness(5)), - ], - q_c: FieldElement::zero(), - }), - Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { - input: FunctionInput::witness(Witness(3), 32), - }), - ]; - - let mut private_parameters = BTreeSet::new(); - private_parameters.insert(Witness(0)); - private_parameters.insert(Witness(1)); - let circuit = Circuit { - current_witness_index: 5, - expression_width: ExpressionWidth::Bounded { width: 4 }, - opcodes, - private_parameters, - public_parameters: PublicInputs::default(), - return_values: PublicInputs::default(), - assert_messages: Default::default(), - }; - must_check_circuit(circuit); + let src = " + current witness index : _5 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (1, _0, _0) (-1, _4) 0 ] + EXPR [ (1, _0, _1) (1, _5) 0 ] + EXPR [ (-1, _2) (1, _4) (1, _5) 0 ] + EXPR [ (1, _2) (-1, _3) (1, _4) (1, _5) 0 ] + BLACKBOX::RANGE [(_3, 32)] [] + "; + let circuit = Circuit::from_str(src).unwrap(); + + let optimized_circuit = merge_expressions(circuit); + assert_circuit_snapshot!(optimized_circuit, @r" + current witness index : _5 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [] + EXPR [ (1, _0, _1) (1, _5) 0 ] + EXPR [ (2, _0, _0) (-1, _3) (2, _5) 0 ] + BLACKBOX::RANGE [(_3, 32)] [] + "); } #[test] @@ -451,43 +363,18 @@ mod tests { // Previously we would not track the usage of witness 4 in the output of the blackbox function. // We would then merge the final two opcodes losing the check that the brillig call must match // with `_0 ^ _1`. - - let circuit: Circuit = Circuit { - current_witness_index: 7, - opcodes: vec![ - Opcode::BrilligCall { - id: BrilligFunctionId(0), - inputs: Vec::new(), - outputs: vec![BrilligOutputs::Simple(Witness(3))], - predicate: None, - }, - Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND { - lhs: FunctionInput::witness(Witness(0), 8), - rhs: FunctionInput::witness(Witness(1), 8), - output: Witness(4), - }), - Opcode::AssertZero(Expression { - linear_combinations: vec![ - (FieldElement::one(), Witness(3)), - (-FieldElement::one(), Witness(4)), - ], - ..Default::default() - }), - Opcode::AssertZero(Expression { - linear_combinations: vec![ - (-FieldElement::one(), Witness(2)), - (FieldElement::one(), Witness(4)), - ], - ..Default::default() - }), - ], - expression_width: ExpressionWidth::Bounded { width: 4 }, - private_parameters: BTreeSet::from([Witness(0), Witness(1)]), - return_values: PublicInputs(BTreeSet::from([Witness(2)])), - ..Default::default() - }; - - let new_circuit = must_check_circuit(circuit.clone()); - assert_eq!(circuit, new_circuit); + let src = " + current witness index : _7 + private parameters indices : [_0, _1] + public parameters indices : [] + return value indices : [_2] + BRILLIG CALL func 0: inputs: [], outputs: [_3] + BLACKBOX::AND [(_0, 8), (_1, 8)] [_4] + EXPR [ (1, _3) (-1, _4) 0 ] + EXPR [ (-1, _2) (1, _4) 0 ] + "; + let circuit = Circuit::from_str(src).unwrap(); + let optimized_circuit = merge_expressions(circuit.clone()); + assert_eq!(circuit, optimized_circuit); } }