diff --git a/acvm-repo/acvm/src/pwg/arithmetic.rs b/acvm-repo/acvm/src/pwg/arithmetic.rs index a2921bcbc9b..0708dfb2dc5 100644 --- a/acvm-repo/acvm/src/pwg/arithmetic.rs +++ b/acvm-repo/acvm/src/pwg/arithmetic.rs @@ -59,7 +59,7 @@ impl ExpressionSolver { Ok(()) } } else { - let assignment = -total_sum / (q + b); + let assignment = -quick_invert(total_sum, q + b); insert_value(&w1, assignment, initial_witness) } } else { @@ -88,7 +88,7 @@ impl ExpressionSolver { Ok(()) } } else { - let assignment = -(total_sum / partial_prod); + let assignment = -quick_invert(total_sum, partial_prod); insert_value(&unknown_var, assignment, initial_witness) } } @@ -122,7 +122,7 @@ impl ExpressionSolver { Ok(()) } } else { - let assignment = -(total_sum / coeff); + let assignment = -quick_invert(total_sum, coeff); insert_value(&unknown_var, assignment, initial_witness) } } @@ -248,11 +248,34 @@ impl ExpressionSolver { } } +/// A wrapper around field division which skips the inversion if the denominator +/// is ±1. +/// +/// Field inversion is the most significant cost of solving [`Opcode::AssertZero`][acir::circuit::opcodes::Opcode::AssertZero] +/// opcodes, we can avoid this in the situation +fn quick_invert(numerator: F, denominator: F) -> F { + if denominator == F::one() { + numerator + } else if denominator == -F::one() { + -numerator + } else { + numerator / denominator + } +} + #[cfg(test)] mod tests { use super::*; use acir::FieldElement; + #[test] + /// Sanity check for the special cases of [`quick_invert`] + fn quick_invert_matches_slow_invert() { + let numerator = FieldElement::from_be_bytes_reduce("hello_world".as_bytes()); + assert_eq!(quick_invert(numerator, FieldElement::one()), numerator / FieldElement::one()); + assert_eq!(quick_invert(numerator, -FieldElement::one()), numerator / -FieldElement::one()); + } + #[test] fn solves_simple_assignment() { let a = Witness(0);