diff --git a/Cargo.lock b/Cargo.lock index 74f916e8a81..ada55410c6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,7 +85,6 @@ dependencies = [ "blake3", "cbc", "criterion", - "itertools 0.14.0", "k256", "keccak 0.2.0-rc.2", "log", diff --git a/acvm-repo/blackbox_solver/Cargo.toml b/acvm-repo/blackbox_solver/Cargo.toml index 00012bd300a..04c19b0450e 100644 --- a/acvm-repo/blackbox_solver/Cargo.toml +++ b/acvm-repo/blackbox_solver/Cargo.toml @@ -17,7 +17,6 @@ workspace = true [dependencies] acir.workspace = true -itertools.workspace = true thiserror.workspace = true log.workspace = true diff --git a/acvm-repo/blackbox_solver/benches/logic.rs b/acvm-repo/blackbox_solver/benches/logic.rs index e835fea8622..f47afc7c3fd 100644 --- a/acvm-repo/blackbox_solver/benches/logic.rs +++ b/acvm-repo/blackbox_solver/benches/logic.rs @@ -1,7 +1,7 @@ use criterion::{Criterion, black_box, criterion_group, criterion_main}; use std::time::Duration; -use acir::FieldElement; +use acir::{AcirField, FieldElement}; use pprof::criterion::{Output, PProfProfiler}; @@ -13,7 +13,7 @@ fn bench_logic_ops(c: &mut Criterion) { let mut group = c.benchmark_group("logic_ops"); - for &bits in &[8u32, 32u32, 64u32] { + for &bits in &[8u32, 32u32, 64u32, 128u32] { group.bench_function(format!("bit_and_{bits}bits"), |b| { b.iter(|| { let _ = bit_and(black_box(lhs), black_box(rhs), black_box(bits)); @@ -27,6 +27,23 @@ fn bench_logic_ops(c: &mut Criterion) { }); } + // Bench with large field elements (values > 128 bits) to exercise the byte-level fallback. + let large_lhs = FieldElement::from_be_bytes_reduce(&[0xFFu8; 32]); + let large_rhs = FieldElement::from_be_bytes_reduce(&[0xAAu8; 32]); + let full_bits = FieldElement::max_num_bits(); + + group.bench_function(format!("bit_and_{full_bits}bits_large"), |b| { + b.iter(|| { + let _ = bit_and(black_box(large_lhs), black_box(large_rhs), black_box(full_bits)); + }); + }); + + group.bench_function(format!("bit_xor_{full_bits}bits_large"), |b| { + b.iter(|| { + let _ = bit_xor(black_box(large_lhs), black_box(large_rhs), black_box(full_bits)); + }); + }); + group.finish(); } diff --git a/acvm-repo/blackbox_solver/src/logic.rs b/acvm-repo/blackbox_solver/src/logic.rs index 993aef73e58..39c9ea8cbf5 100644 --- a/acvm-repo/blackbox_solver/src/logic.rs +++ b/acvm-repo/blackbox_solver/src/logic.rs @@ -1,12 +1,40 @@ use acir::AcirField; -use itertools::Itertools; pub fn bit_and(lhs: F, rhs: F, num_bits: u32) -> F { - bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte & rhs_byte) + // Fast path: use native u128 operations when the bit width fits, + // avoiding all heap allocations from field-to-byte conversions. + if let Some(result) = try_bitwise_u128(lhs, rhs, num_bits, |l, r| l & r) { + result + } else { + bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte & rhs_byte) + } } pub fn bit_xor(lhs: F, rhs: F, num_bits: u32) -> F { - bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte ^ rhs_byte) + // Fast path: use native u128 operations when the bit width fits, + // avoiding all heap allocations from field-to-byte conversions. + if let Some(result) = try_bitwise_u128(lhs, rhs, num_bits, |l, r| l ^ r) { + result + } else { + bitwise_op(lhs, rhs, num_bits, |lhs_byte, rhs_byte| lhs_byte ^ rhs_byte) + } +} + +/// Attempt to perform a bitwise operation using native u128 arithmetic. +/// Returns `None` if `num_bits > 128` or either operand doesn't fit in a u128. +fn try_bitwise_u128( + lhs: F, + rhs: F, + num_bits: u32, + op: fn(u128, u128) -> u128, +) -> Option { + if num_bits > 128 { + return None; + } + let l = lhs.try_into_u128()?; + let r = rhs.try_into_u128()?; + let mask = if num_bits >= 128 { u128::MAX } else { (1u128 << num_bits) - 1 }; + Some(F::from(op(l, r) & mask)) } /// Performs a bitwise operation on two field elements by treating them as byte arrays. @@ -16,16 +44,15 @@ pub fn bit_xor(lhs: F, rhs: F, num_bits: u32) -> F { /// and the result is converted back to a field element. /// This function works for any `num_bits` value and does not assume it to be a multiple of 8. fn bitwise_op(lhs: F, rhs: F, num_bits: u32, op: fn(u8, u8) -> u8) -> F { - // We could explicitly expect `num_bits` to be a multiple of 8 as most backends assume bytes: - // assert!(num_bits % 8 == 0, "num_bits is not a multiple of 8, it is {num_bits}"); - - let lhs_bytes = mask_to_le_bytes(lhs, num_bits); + let mut lhs_bytes = mask_to_le_bytes(lhs, num_bits); let rhs_bytes = mask_to_le_bytes(rhs, num_bits); - let and_byte_arr: Vec<_> = - lhs_bytes.into_iter().zip_eq(rhs_bytes).map(|(left, right)| op(left, right)).collect(); + // Operate in-place on lhs_bytes to avoid allocating a third Vec. + for (l, r) in lhs_bytes.iter_mut().zip(rhs_bytes.iter()) { + *l = op(*l, *r); + } - F::from_le_bytes_reduce(&and_byte_arr) + F::from_le_bytes_reduce(&lhs_bytes) } // mask_to methods will not remove any bytes from the field