diff --git a/acvm-repo/acir/src/native_types/expression/mod.rs b/acvm-repo/acir/src/native_types/expression/mod.rs index 037cd538c59..aac11e7f452 100644 --- a/acvm-repo/acir/src/native_types/expression/mod.rs +++ b/acvm-repo/acir/src/native_types/expression/mod.rs @@ -528,6 +528,116 @@ mod tests { assert_eq!(result.to_string(), "30*w0*w2 + 32*w1*w2 + 40*w4*w5 + 40*w4 + 10"); } + #[test] + fn add_mul_with_zero_coefficient() { + // When k=0, should return a clone of self + let a = Expression::from_str("2*w1*w2 + 3*w1 + 5").unwrap(); + let b = Expression::from_str("4*w2*w3 + 6*w2 + 7").unwrap(); + let k = FieldElement::zero(); + + let result = a.add_mul(k, &b); + assert_eq!(result, a); + } + + #[test] + fn add_mul_when_self_is_const() { + // When self is a constant, should return k*b + constant + let a = Expression::from_field(FieldElement::from(5u128)); + let b = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap(); + let k = FieldElement::from(2u128); + + let result = a.add_mul(k, &b); + assert_eq!(result.to_string(), "4*w1*w2 + 6*w1 + 13"); + } + + #[test] + fn add_mul_when_b_is_const() { + // When b is a constant, should return self + k*constant + let a = Expression::from_str("2*w1*w2 + 3*w1 + 4").unwrap(); + let b = Expression::from_field(FieldElement::from(5u128)); + let k = FieldElement::from(3u128); + + let result = a.add_mul(k, &b); + assert_eq!(result.to_string(), "2*w1*w2 + 3*w1 + 19"); + } + + #[test] + fn add_mul_merges_linear_terms() { + // Test that linear terms with same witness are merged correctly + let a = Expression::from_str("5*w1 + 3*w2").unwrap(); + let b = Expression::from_str("2*w1 + 4*w3").unwrap(); + let k = FieldElement::from(2u128); + + let result = a.add_mul(k, &b); + // 5*w1 + 3*w2 + 2*(2*w1 + 4*w3) = 5*w1 + 3*w2 + 4*w1 + 8*w3 = 9*w1 + 3*w2 + 8*w3 + assert_eq!(result.to_string(), "9*w1 + 3*w2 + 8*w3"); + } + + #[test] + fn add_mul_merges_mul_terms() { + // Test that multiplication terms with same witness pair are merged correctly + let a = Expression::from_str("5*w1*w2 + 3*w3*w4").unwrap(); + let b = Expression::from_str("2*w1*w2 + 4*w5*w6").unwrap(); + let k = FieldElement::from(3u128); + + let result = a.add_mul(k, &b); + // 5*w1*w2 + 3*w3*w4 + 3*(2*w1*w2 + 4*w5*w6) = 5*w1*w2 + 3*w3*w4 + 6*w1*w2 + 12*w5*w6 + // = 11*w1*w2 + 3*w3*w4 + 12*w5*w6 + assert_eq!(result.to_string(), "11*w1*w2 + 3*w3*w4 + 12*w5*w6"); + } + + #[test] + fn add_mul_cancels_terms_to_zero() { + // Test that terms that cancel out are removed + let a = Expression::from_str("6*w1 + 3*w1*w2").unwrap(); + let b = Expression::from_str("3*w1 + w1*w2").unwrap(); + let k = FieldElement::from(-2i128); + + let result = a.add_mul(k, &b); + // 6*w1 + 3*w1*w2 + (-2)*(3*w1 + w1*w2) = 6*w1 + 3*w1*w2 - 6*w1 - 2*w1*w2 + // = w1*w2 + assert_eq!(result.to_string(), "w1*w2"); + } + + #[test] + fn add_mul_maintains_sorted_order() { + // Test that the result maintains sorted order for deterministic output + let a = Expression::from_str("w5 + w1*w3").unwrap(); + let b = Expression::from_str("w2 + w0*w1").unwrap(); + let k = FieldElement::one(); + + let result = a.add_mul(k, &b); + // Result should have terms in sorted order + assert!(result.is_sorted()); + assert_eq!(result.to_string(), "w0*w1 + w1*w3 + w2 + w5"); + } + + #[test] + fn add_mul_with_constant_terms() { + // Test handling of constant terms + let a = Expression::from_str("2*w1 + 10").unwrap(); + let b = Expression::from_str("3*w2 + 5").unwrap(); + let k = FieldElement::from(4u128); + + let result = a.add_mul(k, &b); + // 2*w1 + 10 + 4*(3*w2 + 5) = 2*w1 + 10 + 12*w2 + 20 = 2*w1 + 12*w2 + 30 + assert_eq!(result.to_string(), "2*w1 + 12*w2 + 30"); + } + + #[test] + fn add_mul_complex_expression() { + // Test a complex expression with all types of terms + let a = Expression::from_str("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11").unwrap(); + let b = Expression::from_str("w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13").unwrap(); + let k = FieldElement::from(2u128); + + let result = a.add_mul(k, &b); + // 2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11 + 2*(w1*w2 + 4*w5*w6 + 2*w1 + 6*w5 + 13) + // = 2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w3 + 11 + 2*w1*w2 + 8*w5*w6 + 4*w1 + 12*w5 + 26 + // = 4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37 + assert_eq!(result.to_string(), "4*w1*w2 + 3*w3*w4 + 8*w5*w6 + 9*w1 + 7*w3 + 12*w5 + 37"); + } + #[test] fn display_zero() { let zero = Expression::::default(); diff --git a/acvm-repo/acir/src/native_types/expression/operators.rs b/acvm-repo/acir/src/native_types/expression/operators.rs index 6e475aa357b..41fe18760f0 100644 --- a/acvm-repo/acir/src/native_types/expression/operators.rs +++ b/acvm-repo/acir/src/native_types/expression/operators.rs @@ -1,9 +1,6 @@ use crate::native_types::Witness; use acir_field::AcirField; -use std::{ - cmp::Ordering, - ops::{Add, Mul, Neg, Sub}, -}; +use std::ops::{Add, Mul, Neg, Sub}; use super::Expression; @@ -12,16 +9,34 @@ use super::Expression; impl Neg for &Expression { type Output = Expression; fn neg(self) -> Self::Output { - // XXX(med) : Implement an efficient way to do this + let mut mul_terms = self.mul_terms.clone(); + for (q_m, _, _) in &mut mul_terms { + *q_m = -*q_m; + } - let mul_terms: Vec<_> = - self.mul_terms.iter().map(|(q_m, w_l, w_r)| (-*q_m, *w_l, *w_r)).collect(); + let mut linear_combinations = self.linear_combinations.clone(); + for (q_k, _) in &mut linear_combinations { + *q_k = -*q_k; + } + + Expression { mul_terms, linear_combinations, q_c: -self.q_c } + } +} + +impl Neg for Expression { + type Output = Expression; + fn neg(mut self) -> Self::Output { + for (q_m, _, _) in &mut self.mul_terms { + *q_m = -*q_m; + } + + for (q_k, _) in &mut self.linear_combinations { + *q_k = -*q_k; + } - let linear_combinations: Vec<_> = - self.linear_combinations.iter().map(|(q_k, w_k)| (-*q_k, *w_k)).collect(); - let q_c = -self.q_c; + self.q_c = -self.q_c; - Expression { mul_terms, linear_combinations, q_c } + self } } @@ -128,62 +143,34 @@ impl Mul<&Expression> for &Expression { return None; } + // Start with the constant term: q_c_self * q_c_rhs let mut output = Expression::from_field(self.q_c * rhs.q_c); - //TODO to optimize... + // 'each linear term in self' * 'each linear term in rhs' + // XXX: This has a quadratic cost that can be improved, but for now we favor simplicity. for lc in &self.linear_combinations { let single = single_mul(lc.1, rhs); output = output.add_mul(lc.0, &single); } - //linear terms - let mut i1 = 0; //a - let mut i2 = 0; //b - while i1 < self.linear_combinations.len() && i2 < rhs.linear_combinations.len() { - let (a_c, a_w) = self.linear_combinations[i1]; - let (b_c, b_w) = rhs.linear_combinations[i2]; - - // Apply scaling from multiplication - let a_c = rhs.q_c * a_c; - let b_c = self.q_c * b_c; - - let (coeff, witness) = match a_w.cmp(&b_w) { - Ordering::Greater => { - i2 += 1; - (b_c, b_w) - } - Ordering::Less => { - i1 += 1; - (a_c, a_w) - } - Ordering::Equal => { - // Here we're taking both terms as the witness indices are equal. - // We then advance both `i1` and `i2`. - i1 += 1; - i2 += 1; - (a_c + b_c, a_w) - } + // Add linear terms from self scaled by rhs's constant: self.linear * rhs.q_c + if !rhs.q_c.is_zero() { + let self_linear = Expression { + mul_terms: Vec::new(), + linear_combinations: self.linear_combinations.clone(), + q_c: F::zero(), }; - - if !coeff.is_zero() { - output.linear_combinations.push((coeff, witness)); - } + output = output.add_mul(rhs.q_c, &self_linear); } - while i1 < self.linear_combinations.len() { - let (a_c, a_w) = self.linear_combinations[i1]; - let coeff = rhs.q_c * a_c; - if !coeff.is_zero() { - output.linear_combinations.push((coeff, a_w)); - } - i1 += 1; - } - while i2 < rhs.linear_combinations.len() { - let (b_c, b_w) = rhs.linear_combinations[i2]; - let coeff = self.q_c * b_c; - if !coeff.is_zero() { - output.linear_combinations.push((coeff, b_w)); - } - i2 += 1; + + // Add linear terms from rhs scaled by self's constant: rhs.linear * self.q_c + if !self.q_c.is_zero() { + let rhs_linear = Expression { + mul_terms: Vec::new(), + linear_combinations: rhs.linear_combinations.clone(), + q_c: F::zero(), + }; + output = output.add_mul(self.q_c, &rhs_linear); } Some(output) @@ -208,6 +195,7 @@ fn single_mul(w: Witness, b: &Expression) -> Expression { #[cfg(test)] mod tests { use crate::native_types::Expression; + use acir_field::{AcirField, FieldElement}; #[test] fn add_smoke_test() { @@ -230,4 +218,264 @@ mod tests { // Enforce commutativity assert_eq!(&a * &b, &b * &a); } + + #[test] + fn mul_by_zero_constant() { + // Multiplying by zero should give zero (with zero coefficients) + // Note: The implementation may leave zero-coefficient terms in place + let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap(); + let zero: Expression = Expression::zero(); + + let result = (&a * &zero).unwrap(); + // All terms should have zero coefficients and the constant should be zero + assert!(result.mul_terms.is_empty()); + assert!(result.q_c.is_zero()); + for (coeff, _) in &result.linear_combinations { + assert!(coeff.is_zero()); + } + + // Enforce commutativity + assert_eq!(&a * &zero, &zero * &a); + } + + #[test] + fn mul_by_one_constant() { + // Multiplying by one should give the same expression + let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap(); + let one: Expression = Expression::one(); + + let result = (&a * &one).unwrap(); + assert_eq!(result, a); + + // Enforce commutativity + assert_eq!(&a * &one, &one * &a); + } + + #[test] + fn mul_by_scalar_constant() { + // Multiplying by a constant should scale all terms + let a = Expression::from_str("2*w1 + 3*w2 + 4").unwrap(); + let scalar = Expression::from_field(FieldElement::from(5u128)); + + let result = (&a * &scalar).unwrap(); + assert_eq!(result.to_string(), "10*w1 + 15*w2 + 20"); + + // Enforce commutativity + assert_eq!(&a * &scalar, &scalar * &a); + } + + #[test] + fn mul_two_constants() { + // Multiplying two constants + let a = Expression::from_field(FieldElement::from(3u128)); + let b = Expression::from_field(FieldElement::from(7u128)); + + let result = (&a * &b).unwrap(); + assert_eq!(result, Expression::from_field(FieldElement::from(21u128))); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_linear_expressions() { + // Test multiplication of two linear expressions (no constants) + let a = Expression::from_str("2*w1 + 3*w2").unwrap(); + let b = Expression::from_str("4*w3 + 5*w4").unwrap(); + + let result = (&a * &b).unwrap(); + // (2*w1 + 3*w2) * (4*w3 + 5*w4) = 8*w1*w3 + 10*w1*w4 + 12*w2*w3 + 15*w2*w4 + assert_eq!(result.to_string(), "8*w1*w3 + 10*w1*w4 + 12*w2*w3 + 15*w2*w4"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_with_shared_witness() { + // Test multiplication where both expressions share a witness + let a = Expression::from_str("2*w1 + 3*w2").unwrap(); + let b = Expression::from_str("4*w1 + 5*w3").unwrap(); + + let result = (&a * &b).unwrap(); + // (2*w1 + 3*w2) * (4*w1 + 5*w3) = 8*w1*w1 + 10*w1*w3 + 12*w1*w2 + 15*w2*w3 + assert_eq!(result.to_string(), "8*w1*w1 + 12*w1*w2 + 10*w1*w3 + 15*w2*w3"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_single_witness() { + // Test squaring a single witness: (w1) * (w1) = w1*w1 + let a = Expression::from_str("w1").unwrap(); + let b = Expression::from_str("w1").unwrap(); + + let result = (&a * &b).unwrap(); + assert_eq!(result.to_string(), "w1*w1"); + } + + #[test] + fn mul_with_constant_term() { + // Test multiplication where one expression has a constant term + let a = Expression::from_str("2*w1 + 3").unwrap(); + let b = Expression::from_str("4*w2 + 5").unwrap(); + + let result = (&a * &b).unwrap(); + // (2*w1 + 3) * (4*w2 + 5) = 8*w1*w2 + 10*w1 + 12*w2 + 15 + assert_eq!(result.to_string(), "8*w1*w2 + 10*w1 + 12*w2 + 15"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_degree_two_fails() { + // Multiplying expressions that would result in degree > 2 should return None + let a = Expression::from_str("2*w1*w2 + 3*w1").unwrap(); + let b = Expression::from_str("4*w3 + 5").unwrap(); + + let result = &a * &b; + assert!(result.is_none(), "Multiplication should fail for degree > 2"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_both_degree_two_fails() { + // Multiplying two degree-2 expressions should fail + let a = Expression::from_str("w1*w2").unwrap(); + let b = Expression::from_str("w3*w4").unwrap(); + + let result = &a * &b; + assert!(result.is_none(), "Multiplication of two degree-2 expressions should fail"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_complex_linear_expressions() { + // Test a more complex multiplication + let a = Expression::from_str("2*w1 + 3*w2 + 4*w3 + 5").unwrap(); + let b = Expression::from_str("6*w4 + 7*w5 + 8").unwrap(); + + let result = (&a * &b).unwrap(); + // (2*w1 + 3*w2 + 4*w3 + 5) * (6*w4 + 7*w5 + 8) + // = 12*w1*w4 + 14*w1*w5 + 18*w2*w4 + 21*w2*w5 + 24*w3*w4 + 28*w3*w5 + // + 16*w1 + 24*w2 + 32*w3 + 30*w4 + 35*w5 + 40 + assert_eq!( + result.to_string(), + "12*w1*w4 + 14*w1*w5 + 18*w2*w4 + 21*w2*w5 + 24*w3*w4 + 28*w3*w5 + 16*w1 + 24*w2 + 32*w3 + 30*w4 + 35*w5 + 40" + ); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_witness_ordering() { + // Test that witness pairs are ordered correctly (smaller index first) + let a = Expression::from_str("w5").unwrap(); + let b = Expression::from_str("w2").unwrap(); + + let result = (&a * &b).unwrap(); + // Should be w2*w5, not w5*w2 + assert_eq!(result.to_string(), "w2*w5"); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } + + #[test] + fn mul_result_is_sorted() { + // Verify the witness ordering in mul_terms is correct + let a = Expression::from_str("w3 + w1").unwrap(); + let b = Expression::from_str("w4 + w2").unwrap(); + + let result = (&a * &b).unwrap(); + // Verify that each mul_term has properly ordered witnesses (smaller first) + for (_, wl, wr) in &result.mul_terms { + assert!(wl <= wr, "Witnesses in mul_terms should be ordered"); + } + } + + #[test] + fn neg_reference() { + // Test negation of a reference (uses clone + in-place negate) + let a = Expression::from_str("2*w1*w2 + 3*w1 + 5*w2 + 7").unwrap(); + let result = -&a; + + assert_eq!(result.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7"); + + // Original should be unchanged + assert_eq!(a.to_string(), "2*w1*w2 + 3*w1 + 5*w2 + 7"); + } + + #[test] + fn neg_owned() { + // Test negation of an owned expression (in-place, no clone) + let a = Expression::from_str("2*w1*w2 + 3*w1 + 5*w2 + 7").unwrap(); + let result = -a; + + assert_eq!(result.to_string(), "-2*w1*w2 - 3*w1 - 5*w2 - 7"); + } + + #[test] + fn neg_zero() { + // Negating zero should give zero + let zero: Expression = Expression::zero(); + let result = -&zero; + + assert_eq!(result, Expression::zero()); + } + + #[test] + fn neg_constant() { + // Negating a constant expression + let a = Expression::from_field(FieldElement::from(42u128)); + let result = -a; + + assert_eq!(result.q_c, FieldElement::from(-42i128)); + assert!(result.mul_terms.is_empty()); + assert!(result.linear_combinations.is_empty()); + } + + #[test] + fn neg_linear_only() { + // Negating an expression with only linear terms + let a = Expression::from_str("3*w1 + 5*w2 + 7").unwrap(); + let result = -a; + + assert_eq!(result.to_string(), "-3*w1 - 5*w2 - 7"); + } + + #[test] + fn neg_mul_only() { + // Negating an expression with only multiplication terms + let a = Expression::from_str("2*w1*w2 + 4*w3*w4").unwrap(); + let result = -a; + + assert_eq!(result.to_string(), "-2*w1*w2 - 4*w3*w4"); + } + + #[test] + fn double_neg() { + // Double negation should give back the original + let a = Expression::from_str("2*w1*w2 + 3*w1 + 5").unwrap(); + let result = -(-a.clone()); + + assert_eq!(result, a); + } + + #[test] + fn neg_preserves_structure() { + // Negation should preserve the structure (number of terms) + let a = Expression::from_str("2*w1*w2 + 3*w3*w4 + 5*w1 + 7*w2 + 11").unwrap(); + let result = -&a; + + assert_eq!(result.mul_terms.len(), a.mul_terms.len()); + assert_eq!(result.linear_combinations.len(), a.linear_combinations.len()); + } } diff --git a/acvm-repo/acir/src/native_types/witness_map.rs b/acvm-repo/acir/src/native_types/witness_map.rs index 847ba04453d..cae4c53e302 100644 --- a/acvm-repo/acir/src/native_types/witness_map.rs +++ b/acvm-repo/acir/src/native_types/witness_map.rs @@ -117,3 +117,70 @@ impl Deserialize<'a>> WitnessMap { .map_err(|e| WitnessMapError(SerializationError::Deserialize(e))) } } + +#[cfg(test)] +mod tests { + use super::*; + use acir_field::FieldElement; + + #[test] + fn test_round_trip_serialization() { + // Create a witness map with several entries + let mut original = WitnessMap::new(); + original.insert(Witness(0), FieldElement::from(42u128)); + original.insert(Witness(1), FieldElement::from(123u128)); + original.insert(Witness(5), FieldElement::from(999u128)); + original.insert(Witness(10), FieldElement::zero()); + original.insert(Witness(100), FieldElement::one()); + + // Serialize + let serialized = original.serialize().expect("Serialization should succeed"); + + // Deserialize + let deserialized = + WitnessMap::deserialize(&serialized).expect("Deserialization should succeed"); + + // Verify round trip + assert_eq!(original, deserialized); + } + + #[test] + fn test_round_trip_empty_witness_map() { + // Test with an empty witness map + let original = WitnessMap::::new(); + + let serialized = original.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessMap::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_round_trip_single_entry() { + // Test with a single entry + let mut original = WitnessMap::new(); + original.insert(Witness(0), FieldElement::from(12345u128)); + + let serialized = original.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessMap::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_round_trip_large_field_elements() { + // Test with large field elements + let mut original = WitnessMap::new(); + original.insert(Witness(0), FieldElement::from(u128::MAX)); + original.insert(Witness(1), -FieldElement::one()); + original.insert(Witness(2), FieldElement::from(u128::MAX / 2)); + + let serialized = original.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessMap::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(original, deserialized); + } +} diff --git a/acvm-repo/acir/src/native_types/witness_stack.rs b/acvm-repo/acir/src/native_types/witness_stack.rs index 4207a19c04a..368f85e4abc 100644 --- a/acvm-repo/acir/src/native_types/witness_stack.rs +++ b/acvm-repo/acir/src/native_types/witness_stack.rs @@ -105,3 +105,133 @@ impl From> for WitnessStack { Self { stack } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::native_types::Witness; + use acir_field::FieldElement; + + #[test] + fn test_round_trip_serialization() { + // Create a witness stack with multiple stack items + let mut stack = WitnessStack::default(); + + // First function call with some witnesses + let mut witness1 = WitnessMap::new(); + witness1.insert(Witness(0), FieldElement::from(42u128)); + witness1.insert(Witness(1), FieldElement::from(123u128)); + stack.push(0, witness1); + + // Second function call with different witnesses + let mut witness2 = WitnessMap::new(); + witness2.insert(Witness(0), FieldElement::from(999u128)); + witness2.insert(Witness(5), FieldElement::zero()); + stack.push(1, witness2); + + // Third function call + let mut witness3 = WitnessMap::new(); + witness3.insert(Witness(10), FieldElement::one()); + witness3.insert(Witness(20), FieldElement::from(u128::MAX)); + stack.push(2, witness3); + + // Serialize + let serialized = stack.serialize().expect("Serialization should succeed"); + + // Deserialize + let deserialized = + WitnessStack::deserialize(&serialized).expect("Deserialization should succeed"); + + // Verify round trip + assert_eq!(stack, deserialized); + } + + #[test] + fn test_round_trip_empty_witness_stack() { + // Test with an empty witness stack + let original = WitnessStack::::default(); + + let serialized = original.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessStack::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_round_trip_single_stack_item() { + // Test with a single stack item + let mut stack = WitnessStack::default(); + let mut witness = WitnessMap::new(); + witness.insert(Witness(0), FieldElement::from(12345u128)); + witness.insert(Witness(1), FieldElement::from(67890u128)); + stack.push(0, witness); + + let serialized = stack.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessStack::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(stack, deserialized); + } + + #[test] + fn test_round_trip_from_witness_map() { + // Test conversion from WitnessMap and serialization + let mut witness = WitnessMap::new(); + witness.insert(Witness(0), FieldElement::from(111u128)); + witness.insert(Witness(1), FieldElement::from(222u128)); + witness.insert(Witness(2), FieldElement::from(333u128)); + + let original = WitnessStack::from(witness); + + let serialized = original.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessStack::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_round_trip_large_stack() { + // Test with many stack items + let mut stack = WitnessStack::default(); + + for i in 0..10 { + let mut witness = WitnessMap::new(); + witness.insert(Witness(i), FieldElement::from(u128::from(i) * 100)); + witness.insert(Witness(i + 100), FieldElement::from(u128::from(i) * 1000)); + stack.push(i, witness); + } + + let serialized = stack.serialize().expect("Serialization should succeed"); + let deserialized = + WitnessStack::deserialize(&serialized).expect("Deserialization should succeed"); + + assert_eq!(stack, deserialized); + } + + #[test] + fn test_stack_operations() { + // Test stack operations work correctly + let mut stack = WitnessStack::default(); + + let mut witness1 = WitnessMap::new(); + witness1.insert(Witness(0), FieldElement::from(1u128)); + stack.push(0, witness1.clone()); + + let mut witness2 = WitnessMap::new(); + witness2.insert(Witness(1), FieldElement::from(2u128)); + stack.push(1, witness2.clone()); + + assert_eq!(stack.length(), 2); + assert_eq!(stack.peek().unwrap().index, 1); + + let popped = stack.pop().unwrap(); + assert_eq!(popped.index, 1); + assert_eq!(popped.witness, witness2); + + assert_eq!(stack.length(), 1); + assert_eq!(stack.peek().unwrap().index, 0); + assert_eq!(stack.length(), 1); + } +}