Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .noir-sync-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
55545d630a5b338cf97068d23695779c32e5109b
c3deb6ab504df75ae8c90d483d53083c6cd8d443
260 changes: 174 additions & 86 deletions noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use acir::brillig::{BinaryFieldOp, BinaryIntOp, IntegerBitSize};
use std::ops::{BitAnd, BitOr, BitXor, Shl, Shr};

use acir::brillig::{BinaryFieldOp, BinaryIntOp, BitSize, IntegerBitSize};
use acir::AcirField;
use num_bigint::BigUint;
use num_traits::{AsPrimitive, PrimInt, WrappingAdd, WrappingMul, WrappingSub};
use num_traits::{CheckedDiv, WrappingAdd, WrappingMul, WrappingSub, Zero};

use crate::memory::{MemoryTypeError, MemoryValue};

Expand All @@ -21,24 +23,20 @@ pub(crate) fn evaluate_binary_field_op<F: AcirField>(
lhs: MemoryValue<F>,
rhs: MemoryValue<F>,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let a = match lhs {
MemoryValue::Field(a) => a,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: bit_size.into(),
op_bit_size: F::max_num_bits(),
});
let a = *lhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
};
let b = match rhs {
MemoryValue::Field(b) => b,
MemoryValue::Integer(_, bit_size) => {
return Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: bit_size.into(),
op_bit_size: F::max_num_bits(),
});
})?;
let b = *rhs.expect_field().map_err(|err| {
let MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } = err;
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
};
})?;

Ok(match op {
// Perform addition, subtraction, multiplication, and division based on the BinaryOp variant.
Expand Down Expand Up @@ -70,46 +68,120 @@ pub(crate) fn evaluate_binary_int_op<F: AcirField>(
rhs: MemoryValue<F>,
bit_size: IntegerBitSize,
) -> Result<MemoryValue<F>, BrilligArithmeticError> {
let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
match op {
BinaryIntOp::Add
| BinaryIntOp::Sub
| BinaryIntOp::Mul
| BinaryIntOp::Div
| BinaryIntOp::And
| BinaryIntOp::Or
| BinaryIntOp::Xor => match (lhs, rhs, bit_size) {
(MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
evaluate_binary_int_op_u1(op, lhs, rhs).map(MemoryValue::U1)
}
}
})?;

let rhs_bit_size = if op == &BinaryIntOp::Shl || op == &BinaryIntOp::Shr {
IntegerBitSize::U8
} else {
bit_size
};
(MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U8)
}
(MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U16)
}
(MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U32)
}
(MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U64)
}
(MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
evaluate_binary_int_op_arith(op, lhs, rhs).map(MemoryValue::U128)
}
(lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
(_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: rhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
_ => unreachable!("Invalid arguments are covered by the two arms above."),
},

let rhs = rhs.expect_integer_with_bit_size(rhs_bit_size).map_err(|err| match err {
MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
match (lhs, rhs, bit_size) {
(MemoryValue::U1(lhs), MemoryValue::U1(rhs), IntegerBitSize::U1) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U8(lhs), MemoryValue::U8(rhs), IntegerBitSize::U8) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U16(lhs), MemoryValue::U16(rhs), IntegerBitSize::U16) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U32(lhs), MemoryValue::U32(rhs), IntegerBitSize::U32) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U64(lhs), MemoryValue::U64(rhs), IntegerBitSize::U64) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(MemoryValue::U128(lhs), MemoryValue::U128(rhs), IntegerBitSize::U128) => {
Ok(MemoryValue::U1(evaluate_binary_int_op_cmp(op, lhs, rhs)))
}
(lhs, _, _) if lhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
(_, rhs, _) if rhs.bit_size() != BitSize::Integer(bit_size) => {
Err(BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: rhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
})
}
_ => unreachable!("Invalid arguments are covered by the two arms above."),
}
}
})?;

// `lhs` and `rhs` are asserted to fit within their given types when being read from memory so this is safe.
let result = match bit_size {
IntegerBitSize::U1 => evaluate_binary_int_op_u1(op, lhs != 0, rhs != 0)?.into(),
IntegerBitSize::U8 => evaluate_binary_int_op_num(op, lhs as u8, rhs as u8, 8)?.into(),
IntegerBitSize::U16 => evaluate_binary_int_op_num(op, lhs as u16, rhs as u16, 16)?.into(),
IntegerBitSize::U32 => evaluate_binary_int_op_num(op, lhs as u32, rhs as u32, 32)?.into(),
IntegerBitSize::U64 => evaluate_binary_int_op_num(op, lhs as u64, rhs as u64, 64)?.into(),
IntegerBitSize::U128 => evaluate_binary_int_op_num(op, lhs, rhs, 128)?,
};

Ok(match op {
BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => {
MemoryValue::new_integer(result, IntegerBitSize::U1)
BinaryIntOp::Shl | BinaryIntOp::Shr => {
let rhs = rhs.expect_u8().map_err(
|MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size }| {
BrilligArithmeticError::MismatchedRhsBitSize {
rhs_bit_size: value_bit_size,
op_bit_size: expected_bit_size,
}
},
)?;

match (lhs, bit_size) {
(MemoryValue::U1(lhs), IntegerBitSize::U1) => {
let result = if rhs == 0 { lhs } else { false };
Ok(MemoryValue::U1(result))
}
(MemoryValue::U8(lhs), IntegerBitSize::U8) => {
Ok(MemoryValue::U8(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U16(lhs), IntegerBitSize::U16) => {
Ok(MemoryValue::U16(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U32(lhs), IntegerBitSize::U32) => {
Ok(MemoryValue::U32(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U64(lhs), IntegerBitSize::U64) => {
Ok(MemoryValue::U64(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
(MemoryValue::U128(lhs), IntegerBitSize::U128) => {
Ok(MemoryValue::U128(evaluate_binary_int_op_shifts(op, lhs, rhs)))
}
_ => Err(BrilligArithmeticError::MismatchedLhsBitSize {
lhs_bit_size: lhs.bit_size().to_u32::<F>(),
op_bit_size: bit_size.into(),
}),
}
}
_ => MemoryValue::new_integer(result, bit_size),
})
}
}

fn evaluate_binary_int_op_u1(
Expand All @@ -118,67 +190,83 @@ fn evaluate_binary_int_op_u1(
rhs: bool,
) -> Result<bool, BrilligArithmeticError> {
let result = match op {
BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
BinaryIntOp::Mul => lhs & rhs,
BinaryIntOp::Equals => lhs == rhs,
BinaryIntOp::LessThan => !lhs & rhs,
BinaryIntOp::LessThanEquals => lhs <= rhs,
BinaryIntOp::And | BinaryIntOp::Mul => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor | BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
BinaryIntOp::Div => {
if !rhs {
return Err(BrilligArithmeticError::DivisionByZero);
} else {
lhs
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
};
Ok(result)
}

fn evaluate_binary_int_op_cmp<T: Ord + PartialEq>(op: &BinaryIntOp, lhs: T, rhs: T) -> bool {
match op {
BinaryIntOp::Equals => lhs == rhs,
BinaryIntOp::LessThan => !lhs & rhs,
BinaryIntOp::LessThan => lhs < rhs,
BinaryIntOp::LessThanEquals => lhs <= rhs,
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl | BinaryIntOp::Shr => {
if rhs {
false
_ => unreachable!("Operator not handled by this function: {op:?}"),
}
}

fn evaluate_binary_int_op_shifts<T: From<u8> + Zero + Shl<Output = T> + Shr<Output = T>>(
op: &BinaryIntOp,
lhs: T,
rhs: u8,
) -> T {
match op {
BinaryIntOp::Shl => {
let rhs_usize: usize = rhs as usize;
#[allow(unused_qualifications)]
if rhs_usize >= 8 * std::mem::size_of::<T>() {
T::zero()
} else {
lhs
lhs << rhs.into()
}
}
};
Ok(result)
BinaryIntOp::Shr => {
let rhs_usize: usize = rhs as usize;
#[allow(unused_qualifications)]
if rhs_usize >= 8 * std::mem::size_of::<T>() {
T::zero()
} else {
lhs >> rhs.into()
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
}
}

fn evaluate_binary_int_op_num<
T: PrimInt + AsPrimitive<usize> + From<bool> + WrappingAdd + WrappingSub + WrappingMul,
fn evaluate_binary_int_op_arith<
T: WrappingAdd
+ WrappingSub
+ WrappingMul
+ CheckedDiv
+ BitAnd<Output = T>
+ BitOr<Output = T>
+ BitXor<Output = T>,
>(
op: &BinaryIntOp,
lhs: T,
rhs: T,
num_bits: usize,
) -> Result<T, BrilligArithmeticError> {
let result = match op {
BinaryIntOp::Add => lhs.wrapping_add(&rhs),
BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
BinaryIntOp::Equals => (lhs == rhs).into(),
BinaryIntOp::LessThan => (lhs < rhs).into(),
BinaryIntOp::LessThanEquals => (lhs <= rhs).into(),
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl => {
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
lhs << rhs_usize
}
}
BinaryIntOp::Shr => {
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
lhs >> rhs_usize
}
}
_ => unreachable!("Operator not handled by this function: {op:?}"),
};
Ok(result)
}
Expand Down
24 changes: 9 additions & 15 deletions noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, IntegerBitSize};
use acir::brillig::{BlackBoxOp, HeapArray, HeapVector};
use acir::{AcirField, BlackBoxFunc};
use acvm_blackbox_solver::{
aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccakf1600,
Expand Down Expand Up @@ -312,16 +312,13 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
}
BlackBoxOp::ToRadix { input, radix, output_pointer, num_limbs, output_bits } => {
let input: F = *memory.read(*input).extract_field().expect("ToRadix input not a field");
let radix = memory
.read(*radix)
.expect_integer_with_bit_size(IntegerBitSize::U32)
.expect("ToRadix opcode's radix bit size does not match expected bit size 32");
let MemoryValue::U32(radix) = memory.read(*radix) else {
panic!("ToRadix opcode's radix bit size does not match expected bit size 32")
};
let num_limbs = memory.read(*num_limbs).to_usize();
let output_bits = !memory
.read(*output_bits)
.expect_integer_with_bit_size(IntegerBitSize::U1)
.expect("ToRadix opcode's output_bits size does not match expected bit size 1")
.is_zero();
let MemoryValue::U1(output_bits) = memory.read(*output_bits) else {
panic!("ToRadix opcode's output_bits size does not match expected bit size 1")
};

let mut input = BigUint::from_bytes_be(&input.to_be_bytes());
let radix = BigUint::from_bytes_be(&radix.to_be_bytes());
Expand Down Expand Up @@ -349,13 +346,10 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
for i in (0..num_limbs).rev() {
let limb = &input % &radix;
if output_bits {
limbs[i] = MemoryValue::new_integer(
if limb.is_zero() { 0 } else { 1 },
IntegerBitSize::U1,
);
limbs[i] = MemoryValue::U1(!limb.is_zero());
} else {
let limb: u8 = limb.try_into().unwrap();
limbs[i] = MemoryValue::new_integer(limb as u128, IntegerBitSize::U8);
limbs[i] = MemoryValue::U8(limb);
};
input /= &radix;
}
Expand Down
Loading