Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions acvm-repo/acir/src/native_types/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,116 @@ mod tests {
assert_eq!(result.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10");
}

#[test]
fn add_mul_with_zero_coefficient() {
// When k=0, should return a clone of self
let a = Expression::from_str("2*w1*w2 + 3*w1 + 5").unwrap();
let b = Expression::from_str("4*w2*w3 + 6*w2 + 7").unwrap();
let k = FieldElement::zero();

let result = a.add_mul(k, &b);
assert_eq!(result, a);
}

#[test]
fn add_mul_when_self_is_const() {
// When self is a constant, should return k*b + constant
let a = Expression::from_field(FieldElement::from(5u128));
let b = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap();
let k = FieldElement::from(2u128);

let result = a.add_mul(k, &b);
assert_eq!(result.to_string(), "4*w1*w2 + 6*w1 + 13");
}

#[test]
fn add_mul_when_b_is_const() {
// When b is a constant, should return self + k*constant
let a = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap();
let b = Expression::from_field(FieldElement::from(5u128));
let k = FieldElement::from(3u128);

let result = a.add_mul(k, &b);
assert_eq!(result.to_string(), "2*w1*w2 + 3*w1 + 19");
}

#[test]
fn add_mul_merges_linear_terms() {
// Test that linear terms with same witness are merged correctly
let a = Expression::from_str("5*w1 + 3*w2").unwrap();
let b = Expression::from_str("2*w1 + 4*w3").unwrap();
let k = FieldElement::from(2u128);

let result = a.add_mul(k, &b);
// 5*w1 + 3*w2 + 2*(2*w1 + 4*w3) = 5*w1 + 3*w2 + 4*w1 + 8*w3 = 9*w1 + 3*w2 + 8*w3
assert_eq!(result.to_string(), "9*w1 + 3*w2 + 8*w3");
}

#[test]
fn add_mul_merges_mul_terms() {
// Test that multiplication terms with same witness pair are merged correctly
let a = Expression::from_str("5*w1*w2 + 3*w3*w4").unwrap();
let b = Expression::from_str("2*w1*w2 + 4*w5*w6").unwrap();
let k = FieldElement::from(3u128);

let result = a.add_mul(k, &b);
// 5*w1*w2 + 3*w3*w4 + 3*(2*w1*w2 + 4*w5*w6) = 5*w1*w2 + 3*w3*w4 + 6*w1*w2 + 12*w5*w6
// = 11*w1*w2 + 3*w3*w4 + 12*w5*w6
assert_eq!(result.to_string(), "11*w1*w2 + 3*w3*w4 + 12*w5*w6");
}

#[test]
fn add_mul_cancels_terms_to_zero() {
// Test that terms that cancel out are removed
let a = Expression::from_str("6*w1 + 3*w1*w2").unwrap();
let b = Expression::from_str("3*w1 + w1*w2").unwrap();
let k = FieldElement::from(-2i128);

let result = a.add_mul(k, &b);
// 6*w1 + 3*w1*w2 + (-2)*(3*w1 + w1*w2) = 6*w1 + 3*w1*w2 - 6*w1 - 2*w1*w2
// = w1*w2
assert_eq!(result.to_string(), "w1*w2");
}

#[test]
fn add_mul_maintains_sorted_order() {
// Test that the result maintains sorted order for deterministic output
let a = Expression::from_str("w5 + w1*w3").unwrap();
let b = Expression::from_str("w2 + w0*w1").unwrap();
let k = FieldElement::one();

let result = a.add_mul(k, &b);
// Result should have terms in sorted order
assert!(result.is_sorted());
assert_eq!(result.to_string(), "w0*w1 + w1*w3 + w2 + w5");
}

#[test]
fn add_mul_with_constant_terms() {
// Test handling of constant terms
let a = Expression::from_str("2*w1 + 10").unwrap();
let b = Expression::from_str("3*w2 + 5").unwrap();
let k = FieldElement::from(4u128);

let result = a.add_mul(k, &b);
// 2*w1 + 10 + 4*(3*w2 + 5) = 2*w1 + 10 + 12*w2 + 20 = 2*w1 + 12*w2 + 30
assert_eq!(result.to_string(), "2*w1 + 12*w2 + 30");
}

#[test]
fn add_mul_complex_expression() {
// Test a complex expression with all types of terms
let a = Expression::from_str("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11").unwrap();
let b = Expression::from_str("w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13").unwrap();
let k = FieldElement::from(2u128);

let result = a.add_mul(k, &b);
// 2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11 + 2*(w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13)
// = 2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11 + 2*w1*w2 + 8*w5*w6 + 4*w1 + 12*w5 + 26
// = 4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37
assert_eq!(result.to_string(), "4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37");
}

#[test]
fn display_zero() {
let zero = Expression::<FieldElement>::default();
Expand Down
Loading
Loading