Skip to content
111 changes: 32 additions & 79 deletions acvm-repo/acir/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,51 +409,21 @@ impl PublicInputs {

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use super::{
Circuit, Compression, Opcode, PublicInputs,
opcodes::{BlackBoxFuncCall, FunctionInput},
};
use crate::{circuit::Program, native_types::Witness};
use super::{Circuit, Compression};
use crate::circuit::Program;
use acir_field::{AcirField, FieldElement};
use serde::{Deserialize, Serialize};

fn and_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
lhs: FunctionInput::Witness(Witness(1)),
rhs: FunctionInput::Witness(Witness(2)),
num_bits: 4,
output: Witness(3),
})
}

fn range_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput::Witness(Witness(1)),
num_bits: 8,
})
}

fn keccakf1600_opcode<F: AcirField>() -> Opcode<F> {
let inputs: Box<[FunctionInput<F>; 25]> =
Box::new(std::array::from_fn(|i| FunctionInput::Witness(Witness(i as u32 + 1))));
let outputs: Box<[Witness; 25]> = Box::new(std::array::from_fn(|i| Witness(i as u32 + 26)));

Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs })
}

#[test]
fn serialization_roundtrip() {
let circuit = Circuit {
function_name: "test".to_string(),
current_witness_index: 5,
opcodes: vec![and_opcode::<FieldElement>(), range_opcode()],
private_parameters: BTreeSet::new(),
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])),
assert_messages: Default::default(),
};
let src = "
private parameters: []
public parameters: [w2, w12]
return values: [w4, w12]
BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
BLACKBOX::RANGE input: w1, bits: 8
";
let circuit = Circuit::from_str(src).unwrap();
let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

fn read_write<F: Serialize + for<'a> Deserialize<'a> + AcirField>(
Expand All @@ -470,24 +440,16 @@ mod tests {

#[test]
fn test_serialize() {
let circuit = Circuit {
function_name: "test".to_string(),
current_witness_index: 0,
opcodes: vec![
Opcode::AssertZero(crate::native_types::Expression {
mul_terms: vec![],
linear_combinations: vec![],
q_c: FieldElement::from(8u128),
}),
range_opcode(),
and_opcode(),
keccakf1600_opcode(),
],
private_parameters: BTreeSet::new(),
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
assert_messages: Default::default(),
};
let src = "
private parameters: []
public parameters: [w2]
return values: [w2]
ASSERT 0 = 8
BLACKBOX::RANGE input: w1, bits: 8
BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
";
let circuit = Circuit::from_str(src).unwrap();
let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

let json = serde_json::to_string_pretty(&program).unwrap();
Expand Down Expand Up @@ -516,27 +478,18 @@ mod tests {

#[test]
fn circuit_display_snapshot() {
let circuit = Circuit {
function_name: "test".to_string(),
current_witness_index: 3,
opcodes: vec![
Opcode::AssertZero(crate::native_types::Expression {
mul_terms: vec![],
linear_combinations: vec![(FieldElement::from(2u128), Witness(1))],
q_c: FieldElement::from(8u128),
}),
range_opcode(),
and_opcode(),
keccakf1600_opcode(),
],
private_parameters: BTreeSet::new(),
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])),
assert_messages: Default::default(),
};

// We want to make sure that we witness indices are displayed in a unified format.
// All witnesses are expected to be formatted as `_{witness_index}`.
let src = "
private parameters: []
public parameters: [w2]
return values: [w2]
ASSERT 0 = 2*w1 + 8
BLACKBOX::RANGE input: w1, bits: 8
BLACKBOX::AND lhs: w1, rhs: w2, output: w3, bits: 4
BLACKBOX::KECCAKF1600 inputs: [w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15, w16, w17, w18, w19, w20, w21, w22, w23, w24, w25], outputs: [w26, w27, w28, w29, w30, w31, w32, w33, w34, w35, w36, w37, w38, w39, w40, w41, w42, w43, w44, w45, w46, w47, w48, w49, w50]
";
let circuit = Circuit::from_str(src).unwrap();

// All witnesses are expected to be formatted as `w{witness_index}`.
insta::assert_snapshot!(
circuit.to_string(),
@r"
Expand Down
1 change: 1 addition & 0 deletions acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub use acir_field::{AcirField, FieldElement};
pub use brillig;
pub use circuit::black_box_functions::BlackBoxFunc;
pub use circuit::opcodes::InvalidInputBitSize;
pub use parser::parse_opcodes;

#[cfg(test)]
mod reflection {
Expand Down
31 changes: 4 additions & 27 deletions acvm-repo/acir/src/native_types/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,40 +510,17 @@ fn display_term<F: AcirField, const N: usize>(
#[cfg(test)]
mod tests {
use super::*;
use acir_field::{AcirField, FieldElement};
use acir_field::FieldElement;

#[test]
fn add_mul_smoke_test() {
let a = Expression {
mul_terms: vec![(FieldElement::from(2u128), Witness(1), Witness(2))],
..Default::default()
};
let a = Expression::from_str("2*w1*w2").unwrap();

let k = FieldElement::from(10u128);

let b = Expression {
mul_terms: vec![
(FieldElement::from(3u128), Witness(0), Witness(2)),
(FieldElement::from(3u128), Witness(1), Witness(2)),
(FieldElement::from(4u128), Witness(4), Witness(5)),
],
linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
q_c: FieldElement::one(),
};
let b = Expression::from_str("3*w0*w2 + 3*w1*w2 + 4*w4*w5 + 4*w4 + 1").unwrap();

let result = a.add_mul(k, &b);
assert_eq!(
result,
Expression {
mul_terms: vec![
(FieldElement::from(30u128), Witness(0), Witness(2)),
(FieldElement::from(32u128), Witness(1), Witness(2)),
(FieldElement::from(40u128), Witness(4), Witness(5)),
],
linear_combinations: vec![(FieldElement::from(40u128), Witness(4))],
q_c: FieldElement::from(10u128)
}
);
assert_eq!(result.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10");
}

#[test]
Expand Down
57 changes: 9 additions & 48 deletions acvm-repo/acir/src/native_types/expression/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,64 +207,25 @@ fn single_mul<F: AcirField>(w: Witness, b: &Expression<F>) -> Expression<F> {

#[cfg(test)]
mod tests {
use super::*;
use acir_field::{AcirField, FieldElement};
use crate::native_types::Expression;

#[test]
fn add_smoke_test() {
let a = Expression {
mul_terms: vec![],
linear_combinations: vec![(FieldElement::from(2u128), Witness(2))],
q_c: FieldElement::from(2u128),
};

let b = Expression {
mul_terms: vec![],
linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
q_c: FieldElement::one(),
};

assert_eq!(
&a + &b,
Expression {
mul_terms: vec![],
linear_combinations: vec![
(FieldElement::from(2u128), Witness(2)),
(FieldElement::from(4u128), Witness(4))
],
q_c: FieldElement::from(3u128)
}
);
let a = Expression::from_str("2*w2 + 2").unwrap();
let b = Expression::from_str("4*w4 + 1").unwrap();
let result = Expression::from_str("2*w2 + 4*w4 + 3").unwrap();
assert_eq!(&a + &b, result);

// Enforce commutativity
assert_eq!(&a + &b, &b + &a);
}

#[test]
fn mul_smoke_test() {
let a = Expression {
mul_terms: vec![],
linear_combinations: vec![(FieldElement::from(2u128), Witness(2))],
q_c: FieldElement::from(2u128),
};

let b = Expression {
mul_terms: vec![],
linear_combinations: vec![(FieldElement::from(4u128), Witness(4))],
q_c: FieldElement::one(),
};

assert_eq!(
(&a * &b).unwrap(),
Expression {
mul_terms: vec![(FieldElement::from(8u128), Witness(2), Witness(4)),],
linear_combinations: vec![
(FieldElement::from(2u128), Witness(2)),
(FieldElement::from(8u128), Witness(4))
],
q_c: FieldElement::from(2u128)
}
);
let a = Expression::from_str("2*w2 + 2").unwrap();
let b = Expression::from_str("4*w4 + 1").unwrap();
let result = Expression::from_str("8*w2*w4 + 2*w2 + 8*w4 + 2").unwrap();
assert_eq!((&a * &b).unwrap(), result);

// Enforce commutativity
assert_eq!(&a * &b, &b * &a);
Expand Down
30 changes: 30 additions & 0 deletions acvm-repo/acir/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ impl Circuit<FieldElement> {
}
}

impl FromStr for Expression<FieldElement> {
type Err = AcirParserErrorWithSource;

fn from_str(src: &str) -> Result<Self, Self::Err> {
Self::from_str_impl(src)
}
}

impl Expression<FieldElement> {
/// Creates a [Expression] object from the given string.
#[allow(clippy::should_implement_trait)]
pub fn from_str(src: &str) -> Result<Self, AcirParserErrorWithSource> {
FromStr::from_str(src)
}

pub fn from_str_impl(src: &str) -> Result<Self, AcirParserErrorWithSource> {
let mut parser =
Parser::new(src).map_err(|err| AcirParserErrorWithSource::parse_error(err, src))?;
parser
.parse_arithmetic_expression()
.map_err(|err| AcirParserErrorWithSource::parse_error(err, src))
}
}

pub fn parse_opcodes(src: &str) -> Result<Vec<Opcode<FieldElement>>, AcirParserErrorWithSource> {
let mut parser =
Parser::new(src).map_err(|err| AcirParserErrorWithSource::parse_error(err, src))?;
parser.parse_opcodes().map_err(|err| AcirParserErrorWithSource::parse_error(err, src))
}

struct Parser<'a> {
lexer: Lexer<'a>,
token: SpannedToken,
Expand Down
Loading
Loading