diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index 08576b4a7..572346fe0 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -26,6 +26,7 @@ k256 = { version = "0.7.2", features = [ "arithmetic", ] } indexmap = "1.7.0" +thiserror = "1.0.21" [features] bn254 = ["acir_field/bn254"] diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index c6a46d986..c5857d534 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -15,6 +15,7 @@ use blake2::digest::FixedOutput; use crate::pwg::{arithmetic::ArithmeticSolver, logic::LogicSolver}; use num_bigint::BigUint; use num_traits::{One, Zero}; +use thiserror::Error; // re-export acir pub use acir; @@ -22,11 +23,18 @@ pub use acir::FieldElement; #[derive(PartialEq, Eq, Debug)] pub enum OpcodeResolution { - Resolved, // Opcode is solved - Skip, // Opcode cannot be solved - UnknownError(String), // Generic error - UnsupportedBlackBoxFunc(BlackBoxFunc), // Unsupported black box function - UnsatisfiedConstrain, // Opcode is not satisfied + Resolved, // Opcode is solved + Skip, // Opcode cannot be solved +} + +#[derive(PartialEq, Eq, Debug, Error)] +pub enum OpcodeResolutionError { + #[error("{0}")] + UnknownError(String), + #[error("backend does not currently support the {0} opcode. ACVM does not currently fall back to arithmetic gates.")] + UnsupportedBlackBoxFunc(BlackBoxFunc), + #[error("could not satisfy all constraints")] + UnsatisfiedConstrain, } pub trait Backend: SmartContract + ProofSystemCompiler + PartialWitnessGenerator {} @@ -39,82 +47,73 @@ pub trait PartialWitnessGenerator { &self, initial_witness: &mut BTreeMap, gates: Vec, - ) -> OpcodeResolution { + ) -> Result<(), OpcodeResolutionError> { if gates.is_empty() { - return OpcodeResolution::Resolved; + return Ok(()); } let mut unsolved_gates: Vec = Vec::new(); for gate in gates.into_iter() { - let unsolved = match &gate { - Opcode::Arithmetic(arith) => { - let result = ArithmeticSolver::solve(initial_witness, arith); - match result { - OpcodeResolution::Resolved => false, - OpcodeResolution::Skip => true, - _ => return result, - } - } - Opcode::BlackBoxFuncCall(gc) if gc.name == BlackBoxFunc::RANGE => { - // TODO: this consistency check can be moved to a general function - let defined_input_size = BlackBoxFunc::RANGE - .definition() - .input_size - .fixed_size() - .expect("infallible: input for range gate is fixed"); + let resolution = match &gate { + Opcode::Arithmetic(arith) => ArithmeticSolver::solve(initial_witness, arith)?, + Opcode::BlackBoxFuncCall(bb_func) => match bb_func.name { + BlackBoxFunc::RANGE => { + // TODO: this consistency check can be moved to a general function + let defined_input_size = BlackBoxFunc::RANGE + .definition() + .input_size + .fixed_size() + .expect("infallible: input for range gate is fixed"); - if gc.inputs.len() != defined_input_size as usize { - return OpcodeResolution::UnknownError( - "defined input size does not equal given input size".to_string(), - ); - } + if bb_func.inputs.len() != defined_input_size as usize { + return Err(OpcodeResolutionError::UnknownError( + "defined input size does not equal given input size".to_string(), + )); + } - // For the range constraint, we know that the input size should be one - assert_eq!(defined_input_size, 1); + // For the range constraint, we know that the input size should be one + assert_eq!(defined_input_size, 1); - let input = gc - .inputs - .first() - .expect("infallible: checked that input size is 1"); + let input = bb_func + .inputs + .first() + .expect("infallible: checked that input size is 1"); - if let Some(w_value) = initial_witness.get(&input.witness) { - if w_value.num_bits() > input.num_bits { - return OpcodeResolution::UnsatisfiedConstrain; + if let Some(w_value) = initial_witness.get(&input.witness) { + if w_value.num_bits() > input.num_bits { + return Err(OpcodeResolutionError::UnsatisfiedConstrain); + } + OpcodeResolution::Resolved + } else { + OpcodeResolution::Skip } - false - } else { - true } - } - Opcode::BlackBoxFuncCall(gc) if gc.name == BlackBoxFunc::AND => { - !LogicSolver::solve_and_gate(initial_witness, gc) - } - Opcode::BlackBoxFuncCall(gc) if gc.name == BlackBoxFunc::XOR => { - !LogicSolver::solve_xor_gate(initial_witness, gc) - } - Opcode::BlackBoxFuncCall(gc) => { - let mut unsolvable = false; - for i in &gc.inputs { - if !initial_witness.contains_key(&i.witness) { - unsolvable = true; - break; + BlackBoxFunc::AND => LogicSolver::solve_and_gate(initial_witness, bb_func), + BlackBoxFunc::XOR => LogicSolver::solve_xor_gate(initial_witness, bb_func), + _ => { + let mut unsolvable = false; + for i in &bb_func.inputs { + if !initial_witness.contains_key(&i.witness) { + unsolvable = true; + break; + } + } + if unsolvable { + OpcodeResolution::Skip + } else if let Err(op) = Self::solve_gadget_call(initial_witness, bb_func) { + return Err(OpcodeResolutionError::UnsupportedBlackBoxFunc(op)); + } else { + OpcodeResolution::Resolved } } - if unsolvable { - true - } else if let Err(op) = Self::solve_gadget_call(initial_witness, gc) { - return OpcodeResolution::UnsupportedBlackBoxFunc(op); - } else { - false - } - } + }, Opcode::Directive(directive) => match directive { Directive::Invert { x, result } => match initial_witness.get(x) { - None => true, + None => OpcodeResolution::Skip, Some(val) => { let inverse = val.inverse(); initial_witness.insert(*result, inverse); - false + OpcodeResolution::Resolved } }, Directive::Quotient { @@ -147,12 +146,12 @@ pub trait PartialWitnessGenerator { *r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), ); - false + OpcodeResolution::Resolved } else { - true + OpcodeResolution::Skip } } - _ => true, + _ => OpcodeResolution::Skip, } } Directive::Truncate { a, b, c, bit_size } => match initial_witness.get(a) { @@ -171,9 +170,9 @@ pub trait PartialWitnessGenerator { *c, FieldElement::from_be_bytes_reduce(&int_c.to_bytes_be()), ); - false + OpcodeResolution::Resolved } - _ => true, + _ => OpcodeResolution::Skip, }, Directive::ToBits { a, b, bit_size } => { match Self::get_value(a, initial_witness) { @@ -192,14 +191,16 @@ pub trait PartialWitnessGenerator { } std::collections::btree_map::Entry::Occupied(e) => { if e.get() != &v { - return OpcodeResolution::UnsatisfiedConstrain; + return Err( + OpcodeResolutionError::UnsatisfiedConstrain, + ); } } } } - false + OpcodeResolution::Resolved } - _ => true, + _ => OpcodeResolution::Skip, } } Directive::ToBytes { a, b, byte_size } => { @@ -216,14 +217,16 @@ pub trait PartialWitnessGenerator { } std::collections::btree_map::Entry::Occupied(e) => { if e.get() != &v { - return OpcodeResolution::UnsatisfiedConstrain; + return Err( + OpcodeResolutionError::UnsatisfiedConstrain, + ); } } } } - false + OpcodeResolution::Resolved } - _ => true, + _ => OpcodeResolution::Skip, } } Directive::Oddrange { a, b, r, bit_size } => match initial_witness.get(a) { @@ -231,7 +234,7 @@ pub trait PartialWitnessGenerator { let int_a = BigUint::from_bytes_be(&val_a.to_bytes()); let pow: BigUint = BigUint::one() << (bit_size - 1); if int_a >= (&pow << 1) { - return OpcodeResolution::UnsatisfiedConstrain; + return Err(OpcodeResolutionError::UnsatisfiedConstrain); } let bb = &int_a & &pow; let int_r = &int_a - &bb; @@ -245,13 +248,13 @@ pub trait PartialWitnessGenerator { *r, FieldElement::from_be_bytes_reduce(&int_r.to_bytes_be()), ); - false + OpcodeResolution::Resolved } - _ => true, + _ => OpcodeResolution::Skip, }, }, }; - if unsolved { + if resolution == OpcodeResolution::Skip { unsolved_gates.push(gate); } } diff --git a/acvm/src/pwg/arithmetic.rs b/acvm/src/pwg/arithmetic.rs index d734e2e0c..0fd4107a7 100644 --- a/acvm/src/pwg/arithmetic.rs +++ b/acvm/src/pwg/arithmetic.rs @@ -2,7 +2,7 @@ use acir::native_types::{Expression, Witness}; use acir_field::FieldElement; use std::collections::BTreeMap; -use crate::OpcodeResolution; +use crate::{OpcodeResolution, OpcodeResolutionError}; /// An Arithmetic solver will take a Circuit's arithmetic gates with witness assignments /// and create the other witness variables @@ -26,7 +26,7 @@ impl ArithmeticSolver { pub fn solve( initial_witness: &mut BTreeMap, gate: &Expression, - ) -> OpcodeResolution { + ) -> Result { // Evaluate multiplication term let mul_result = ArithmeticSolver::solve_mul_term(gate, initial_witness); // Evaluate the fan-in terms @@ -34,7 +34,7 @@ impl ArithmeticSolver { match (mul_result, gate_status) { (MulTerm::TooManyUnknowns, _) | (_, GateStatus::GateUnsolvable) => { - OpcodeResolution::Skip + Ok(OpcodeResolution::Skip) } (MulTerm::OneUnknown(q, w1), GateStatus::GateSolvable(a, (b, w2))) => { if w1 == w2 { @@ -42,18 +42,18 @@ impl ArithmeticSolver { let total_sum = a + gate.q_c; if (q + b).is_zero() { if !total_sum.is_zero() { - OpcodeResolution::UnsatisfiedConstrain + Err(OpcodeResolutionError::UnsatisfiedConstrain) } else { - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } else { let assignment = -total_sum / (q + b); // Add this into the witness assignments initial_witness.insert(w1, assignment); - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } else { - OpcodeResolution::Skip + Ok(OpcodeResolution::Skip) } } (MulTerm::OneUnknown(partial_prod, unknown_var), GateStatus::GateSatisfied(sum)) => { @@ -64,24 +64,24 @@ impl ArithmeticSolver { let total_sum = sum + gate.q_c; if partial_prod.is_zero() { if !total_sum.is_zero() { - OpcodeResolution::UnsatisfiedConstrain + Err(OpcodeResolutionError::UnsatisfiedConstrain) } else { - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } else { let assignment = -(total_sum / partial_prod); // Add this into the witness assignments initial_witness.insert(unknown_var, assignment); - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } (MulTerm::Solved(a), GateStatus::GateSatisfied(b)) => { // All the variables in the MulTerm are solved and the Fan-in is also solved // There is nothing to solve if !(a + b + gate.q_c).is_zero() { - OpcodeResolution::UnsatisfiedConstrain + Err(OpcodeResolutionError::UnsatisfiedConstrain) } else { - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } ( @@ -94,15 +94,15 @@ impl ArithmeticSolver { let total_sum = total_prod + partial_sum + gate.q_c; if coeff.is_zero() { if !total_sum.is_zero() { - OpcodeResolution::UnsatisfiedConstrain + Err(OpcodeResolutionError::UnsatisfiedConstrain) } else { - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } else { let assignment = -(total_sum / coeff); // Add this into the witness assignments initial_witness.insert(unknown_var, assignment); - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) } } } @@ -221,11 +221,11 @@ fn arithmetic_smoke_test() { assert_eq!( ArithmeticSolver::solve(&mut values, &gate_a), - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) ); assert_eq!( ArithmeticSolver::solve(&mut values, &gate_b), - OpcodeResolution::Resolved + Ok(OpcodeResolution::Resolved) ); assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128)); diff --git a/acvm/src/pwg/logic.rs b/acvm/src/pwg/logic.rs index d35f316e0..12276b0a0 100644 --- a/acvm/src/pwg/logic.rs +++ b/acvm/src/pwg/logic.rs @@ -3,6 +3,8 @@ use acir::native_types::Witness; use acir_field::FieldElement; use std::collections::BTreeMap; +use crate::OpcodeResolution; + pub struct LogicSolver; impl LogicSolver { @@ -14,13 +16,13 @@ impl LogicSolver { result: Witness, num_bits: u32, is_xor_gate: bool, - ) -> bool { + ) -> OpcodeResolution { let w_l = initial_witness.get(a); let w_r = initial_witness.get(b); let (w_l_value, w_r_value) = match (w_l, w_r) { (Some(w_l_value), Some(w_r_value)) => (w_l_value, w_r_value), - (_, _) => return false, + (_, _) => return OpcodeResolution::Skip, }; if is_xor_gate { @@ -30,20 +32,20 @@ impl LogicSolver { let assignment = w_l_value.and(w_r_value, num_bits); initial_witness.insert(result, assignment); } - true + OpcodeResolution::Resolved } pub fn solve_and_gate( initial_witness: &mut BTreeMap, gate: &BlackBoxFuncCall, - ) -> bool { + ) -> OpcodeResolution { let (a, b, result, num_bits) = extract_input_output(gate); LogicSolver::solve_logic_gate(initial_witness, &a, &b, result, num_bits, false) } pub fn solve_xor_gate( initial_witness: &mut BTreeMap, gate: &BlackBoxFuncCall, - ) -> bool { + ) -> OpcodeResolution { let (a, b, result, num_bits) = extract_input_output(gate); LogicSolver::solve_logic_gate(initial_witness, &a, &b, result, num_bits, true) }