Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 46 additions & 59 deletions acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use acir::brillig::{BinaryFieldOp, BinaryIntOp, IntegerBitSize};
use acir::AcirField;
use num_bigint::BigUint;
use num_traits::{AsPrimitive, PrimInt, WrappingAdd, WrappingMul, WrappingSub};

use crate::memory::{MemoryTypeError, MemoryValue};

Expand Down Expand Up @@ -93,10 +94,14 @@ pub(crate) fn evaluate_binary_int_op<F: AcirField>(
}
})?;

let result = if bit_size == IntegerBitSize::U128 {
evaluate_binary_int_op_128(op, lhs, rhs)?
} else {
evaluate_binary_int_op_generic(op, lhs, rhs, bit_size)?
// `lhs` and `rhs` are asserted to fit within their given types when being read from memory so this is safe.
let result = match bit_size {
IntegerBitSize::U1 => evaluate_binary_int_op_u1(op, lhs != 0, rhs != 0)?.into(),
IntegerBitSize::U8 => evaluate_binary_int_op_num(op, lhs as u8, rhs as u8, 8)?.into(),
IntegerBitSize::U16 => evaluate_binary_int_op_num(op, lhs as u16, rhs as u16, 16)?.into(),
IntegerBitSize::U32 => evaluate_binary_int_op_num(op, lhs as u32, rhs as u32, 32)?.into(),
IntegerBitSize::U64 => evaluate_binary_int_op_num(op, lhs as u64, rhs as u64, 64)?.into(),
IntegerBitSize::U128 => evaluate_binary_int_op_num(op, lhs, rhs, 128)?,
};

Ok(match op {
Expand All @@ -107,89 +112,71 @@ pub(crate) fn evaluate_binary_int_op<F: AcirField>(
})
}

fn evaluate_binary_int_op_128(
fn evaluate_binary_int_op_u1(
op: &BinaryIntOp,
lhs: u128,
rhs: u128,
) -> Result<u128, BrilligArithmeticError> {
lhs: bool,
rhs: bool,
) -> Result<bool, BrilligArithmeticError> {
let result = match op {
BinaryIntOp::Add => lhs.wrapping_add(rhs),
BinaryIntOp::Sub => lhs.wrapping_sub(rhs),
BinaryIntOp::Mul => lhs.wrapping_mul(rhs),
BinaryIntOp::Add | BinaryIntOp::Sub => lhs ^ rhs,
BinaryIntOp::Mul => lhs & rhs,
BinaryIntOp::Div => {
if rhs == 0 {
if !rhs {
return Err(BrilligArithmeticError::DivisionByZero);
} else {
lhs / rhs
lhs
}
}
BinaryIntOp::Equals => (lhs == rhs) as u128,
BinaryIntOp::LessThan => (lhs < rhs) as u128,
BinaryIntOp::LessThanEquals => (lhs <= rhs) as u128,
BinaryIntOp::Equals => lhs == rhs,
BinaryIntOp::LessThan => !lhs & rhs,
BinaryIntOp::LessThanEquals => lhs <= rhs,
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl => {
if rhs >= 128 {
0
} else {
lhs.wrapping_shl(rhs as u32)
}
}
BinaryIntOp::Shr => {
if rhs >= 128 {
0
BinaryIntOp::Shl | BinaryIntOp::Shr => {
if rhs {
false
} else {
lhs.wrapping_shr(rhs as u32)
lhs
}
}
};
Ok(result)
}

fn evaluate_binary_int_op_generic(
fn evaluate_binary_int_op_num<
T: PrimInt + AsPrimitive<usize> + From<bool> + WrappingAdd + WrappingSub + WrappingMul,
>(
op: &BinaryIntOp,
lhs: u128,
rhs: u128,
bit_size: IntegerBitSize,
) -> Result<u128, BrilligArithmeticError> {
let bit_size: u32 = bit_size.into();
let bit_modulo = 1 << bit_size;
lhs: T,
rhs: T,
num_bits: usize,
) -> Result<T, BrilligArithmeticError> {
let result = match op {
// Perform addition, subtraction, and multiplication, applying a modulo operation to keep the result within the bit size.
BinaryIntOp::Add => (lhs + rhs) % bit_modulo,
BinaryIntOp::Sub => (bit_modulo + lhs - rhs) % bit_modulo,
BinaryIntOp::Mul => (lhs * rhs) % bit_modulo,
// Perform unsigned division using the modulo operation on a and b.
BinaryIntOp::Div => {
if rhs == 0 {
return Err(BrilligArithmeticError::DivisionByZero);
} else {
lhs / rhs
}
}
// Perform a == operation, returning 0 or 1
BinaryIntOp::Equals => (lhs == rhs) as u128,
// Perform a < operation, returning 0 or 1
BinaryIntOp::LessThan => (lhs < rhs) as u128,
// Perform a <= operation, returning 0 or 1
BinaryIntOp::LessThanEquals => (lhs <= rhs) as u128,
// Perform bitwise AND, OR, XOR, left shift, and right shift operations, applying a modulo operation to keep the result within the bit size.
BinaryIntOp::Add => lhs.wrapping_add(&rhs),
BinaryIntOp::Sub => lhs.wrapping_sub(&rhs),
BinaryIntOp::Mul => lhs.wrapping_mul(&rhs),
BinaryIntOp::Div => lhs.checked_div(&rhs).ok_or(BrilligArithmeticError::DivisionByZero)?,
BinaryIntOp::Equals => (lhs == rhs).into(),
BinaryIntOp::LessThan => (lhs < rhs).into(),
BinaryIntOp::LessThanEquals => (lhs <= rhs).into(),
BinaryIntOp::And => lhs & rhs,
BinaryIntOp::Or => lhs | rhs,
BinaryIntOp::Xor => lhs ^ rhs,
BinaryIntOp::Shl => {
if rhs >= (bit_size as u128) {
0
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
(lhs << rhs) % bit_modulo
lhs << rhs_usize
}
}
BinaryIntOp::Shr => {
if rhs >= (bit_size as u128) {
0
let rhs_usize = rhs.as_();
if rhs_usize >= num_bits {
T::zero()
} else {
lhs >> rhs
lhs >> rhs_usize
}
}
};
Expand Down