diff --git a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs index afe0aa1f1f2..06333fa0174 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs @@ -597,10 +597,8 @@ impl AcirContext { self.euclidean_division_var(lhs, rhs, bit_size, predicate)?; Ok(quotient_var) } - NumericType::Signed { bit_size } => { - let (quotient_var, _remainder_var) = - self.signed_division_var(lhs, rhs, bit_size, predicate)?; - Ok(quotient_var) + NumericType::Signed { .. } => { + unreachable!("Signed division should have been removed before ACIRgen") } } } @@ -1083,96 +1081,6 @@ impl AcirContext { Ok(()) } - // Returns the 2-complement of lhs, using the provided sign bit in 'leading' - // if leading is zero, it returns lhs - // if leading is one, it returns 2^bit_size-lhs - fn two_complement( - &mut self, - lhs: AcirVar, - leading: AcirVar, - max_bit_size: u32, - ) -> Result { - let max_power_of_two = self.add_constant(power_of_two::(max_bit_size - 1)); - - let intermediate = self.sub_var(max_power_of_two, lhs)?; - let intermediate = self.mul_var(intermediate, leading)?; - - self.add_mul_var(lhs, F::from(2_u128), intermediate) - } - - /// Returns the quotient and remainder such that lhs = rhs * quotient + remainder - /// and |remainder| < |rhs| - /// and remainder has the same sign than lhs - /// Note that this is not the euclidean division, where we have instead remainder < |rhs| - fn signed_division_var( - &mut self, - lhs: AcirVar, - rhs: AcirVar, - bit_size: u32, - predicate: AcirVar, - ) -> Result<(AcirVar, AcirVar), RuntimeError> { - // We derive the signed division from the unsigned euclidean division. - // note that this is not euclidean division! - // If `x` is a signed integer, then `sign(x)x >= 0` - // so if `a` and `b` are signed integers, we can do the unsigned division: - // `sign(a)a = q1*sign(b)b + r1` - // => `a = sign(a)sign(b)q1*b + sign(a)r1` - // => `a = qb+r`, with `|r|<|b|` and `a` and `r` have the same sign. - - assert_ne!(bit_size, 0, "signed integer should have at least one bit"); - - // 2^{max_bit size-1} - let max_power_of_two = self.add_constant(power_of_two::(bit_size - 1)); - let zero = self.add_constant(F::zero()); - let one = self.add_constant(F::one()); - - // Get the sign bit of rhs by computing rhs / max_power_of_two - let (rhs_leading, _) = self.euclidean_division_var(rhs, max_power_of_two, bit_size, one)?; - - // Get the sign bit of lhs by computing lhs / max_power_of_two - let (lhs_leading, _) = self.euclidean_division_var(lhs, max_power_of_two, bit_size, one)?; - - // Signed to unsigned: - let unsigned_lhs = self.two_complement(lhs, lhs_leading, bit_size)?; - let unsigned_rhs = self.two_complement(rhs, rhs_leading, bit_size)?; - - // Performs the division using the unsigned values of lhs and rhs - let (q1, r1) = - self.euclidean_division_var(unsigned_lhs, unsigned_rhs, bit_size, predicate)?; - - // Unsigned to signed: derive q and r from q1,r1 and the signs of lhs and rhs - // Quotient sign is lhs sign * rhs sign, whose resulting sign bit is the XOR of the sign bits - let q_sign = self.xor_var(lhs_leading, rhs_leading, AcirType::unsigned(1))?; - let quotient = self.two_complement(q1, q_sign, bit_size)?; - let remainder = self.two_complement(r1, lhs_leading, bit_size)?; - - // Issue #5129 - When q1 is zero and quotient sign is -1, we compute -0=2^{bit_size}, - // which is not valid because we do not wrap integer operations - // Similar case can happen with the remainder. - let q_is_0 = self.eq_var(q1, zero)?; - let q_is_not_0 = self.not_var(q_is_0, AcirType::unsigned(1))?; - let quotient = self.mul_var(quotient, q_is_not_0)?; - let r_is_0 = self.eq_var(r1, zero)?; - let r_is_not_0 = self.not_var(r_is_0, AcirType::unsigned(1))?; - let remainder = self.mul_var(remainder, r_is_not_0)?; - - // The quotient must be a valid signed integer. - // For instance -128/-1 = 128, but 128 is not a valid i8 - // Because it is the only possible overflow that can happen due to signed representation, - // we simply check for this case: quotient is negative, or distinct from 2^{bit_size-1} - // Indeed, negative quotient cannot 'overflow' because the division will not increase its absolute value - let assert_message = - self.generate_assertion_message_payload("Attempt to divide with overflow".to_string()); - let unsigned = self.not_var(q_sign, AcirType::unsigned(1))?; - - // This overflow check must also be under the predicate - let unsigned = self.mul_var(unsigned, predicate)?; - - self.assert_neq_var(quotient, max_power_of_two, unsigned, Some(assert_message))?; - - Ok((quotient, remainder)) - } - /// Returns a variable which is constrained to be `lhs mod rhs` pub(crate) fn modulo_var( &mut self, @@ -1190,8 +1098,8 @@ impl AcirContext { }; let (_, remainder_var) = match numeric_type { - NumericType::Signed { bit_size } => { - self.signed_division_var(lhs, rhs, bit_size, predicate)? + NumericType::Signed { .. } => { + unreachable!("Signed modulo should have been removed before ACIRgen") } _ => self.euclidean_division_var(lhs, rhs, bit_size, predicate)?, }; @@ -1257,58 +1165,6 @@ impl AcirContext { Ok(remainder) } - /// Returns an 'AcirVar' containing the boolean value lhs diff<2^n, because the 2-complement representation keeps the ordering (e.g in 8 bits -1 is 255 > -2 = 254) - /// If not, lhs positive => diff > 2^n - /// and lhs negative => diff <= 2^n => diff < 2^n (because signs are not the same, so lhs != rhs and so diff != 2^n) - pub(crate) fn less_than_signed( - &mut self, - lhs: AcirVar, - rhs: AcirVar, - bit_count: u32, - ) -> Result { - let pow_last = self.add_constant(F::from(1_u128 << (bit_count - 1))); - let pow = self.add_constant(F::from(1_u128 << (bit_count))); - - // We check whether the inputs have same sign or not by computing the XOR of their bit sign - - // Predicate is always active as `pow_last` is known to be non-zero. - let one = self.add_constant(1_u128); - let lhs_sign = self.div_var( - lhs, - pow_last, - AcirType::NumericType(NumericType::Unsigned { bit_size: bit_count }), - one, - )?; - let rhs_sign = self.div_var( - rhs, - pow_last, - AcirType::NumericType(NumericType::Unsigned { bit_size: bit_count }), - one, - )?; - let same_sign = self.xor_var( - lhs_sign, - rhs_sign, - AcirType::NumericType(NumericType::Unsigned { bit_size: 1 }), - )?; - - // We compute the input difference - let no_underflow = self.add_var(lhs, pow)?; - let diff = self.sub_var(no_underflow, rhs)?; - - // We check the 'bit sign' of the difference - let diff_sign = self.less_than_var(diff, pow, bit_count + 1)?; - - // Then the result is simply diff_sign XOR same_sign (can be checked with a truth table) - self.xor_var( - diff_sign, - same_sign, - AcirType::NumericType(NumericType::Unsigned { bit_size: 1 }), - ) - } - /// Returns an `AcirVar` which will be `1` if lhs >= rhs /// and `0` otherwise. pub(crate) fn more_than_eq_var( diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 0b057438d5e..a0ce13b1cb4 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -763,7 +763,7 @@ impl<'a> Context<'a> { BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs), BinaryOp::Lt => match binary_type { AcirType::NumericType(NumericType::Signed { .. }) => { - self.acir_context.less_than_signed(lhs, rhs, bit_count) + panic!("ICE - signed less than should have been removed before ACIRgen") } _ => self.acir_context.less_than_var(lhs, rhs, bit_count), }, diff --git a/compiler/noirc_evaluator/src/acir/tests/mod.rs b/compiler/noirc_evaluator/src/acir/tests/mod.rs index 7d639e4f2ab..cf5390b3276 100644 --- a/compiler/noirc_evaluator/src/acir/tests/mod.rs +++ b/compiler/noirc_evaluator/src/acir/tests/mod.rs @@ -243,33 +243,6 @@ fn derive_pedersen_generators_requires_constant_input() { .expect_err("Should fail with assert constant"); } -#[test] -// Regression for https://github.com/noir-lang/noir/issues/9847 -fn signed_div_overflow() { - // Test that check -128 / -1 overflow for i8 - let src = r#" - acir(inline) predicate_pure fn main f0 { - b0(v1: i8, v2: i8): - v3 = div v1, v2 - return - } - "#; - - let ssa = Ssa::from_str(src).unwrap(); - let inputs = vec![FieldElement::from(128_u128), FieldElement::from(255_u128)]; - let inputs = inputs - .into_iter() - .enumerate() - .map(|(i, f)| (Witness(i as u32), f)) - .collect::>(); - let initial_witness = WitnessMap::from(inputs); - let output = None; - - // acir execution should fail to divide -128 / -1 - let acir_execution_result = execute_ssa(ssa, initial_witness.clone(), output.as_ref()); - assert!(matches!(acir_execution_result, (ACVMStatus::Failure(_), _))); -} - /// Convert the SSA input into ACIR and use ACVM to execute it /// Returns the ACVM execution status and the value of the 'output' witness value, /// unless the provided output is None or the ACVM fails during execution. diff --git a/compiler/noirc_evaluator/src/ssa/mod.rs b/compiler/noirc_evaluator/src/ssa/mod.rs index 703cbb1c624..2689a436da8 100644 --- a/compiler/noirc_evaluator/src/ssa/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/mod.rs @@ -174,6 +174,9 @@ pub fn primary_passes(options: &SsaEvaluatorOptions) -> Vec> { SsaPass::new(Ssa::simplify_cfg, "Simplifying"), SsaPass::new(Ssa::mem2reg, "Mem2Reg"), SsaPass::new(Ssa::remove_bit_shifts, "Removing Bit Shifts"), + // Expand signed lt/div/mod after "Removing Bit Shifts" because that pass might + // introduce signed divisions. + SsaPass::new(Ssa::expand_signed_math, "Expand signed math"), SsaPass::new(Ssa::simplify_cfg, "Simplifying"), SsaPass::new(Ssa::flatten_cfg, "Flattening"), // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores diff --git a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs index 7022eeb7063..2e6bea12bf9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs @@ -1,3 +1,9 @@ +/// An SSA pass that transforms the checked signed arithmetic operations add, sub and mul +/// into unchecked operations followed by explicit overflow checks. +/// +/// The purpose of this pass is to avoid ACIR and Brillig having to handle checked signed arithmetic +/// operations, while also allowing further optimizations to be done during subsequent +/// SSA passes on the expanded instructions. use acvm::{FieldElement, acir::AcirField}; use crate::ssa::{ @@ -29,7 +35,7 @@ impl Function { /// The structure of this pass is simple: /// Go through each block and re-insert all instructions, decomposing any checked signed arithmetic to have explicit /// overflow checks. - pub(crate) fn expand_signed_checks(&mut self) { + fn expand_signed_checks(&mut self) { // TODO: consider whether we can implement this more efficiently in brillig. self.simple_optimization(|context| { diff --git a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs new file mode 100644 index 00000000000..5d80638582f --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs @@ -0,0 +1,552 @@ +/// An SSA pass for ACIR functions that transforms "less than", "div" and "mod" operation on +/// signed integers into equivalent sequences of operations that rely on unsigned integers. +/// +/// The purpose of this pass is to avoid ACIR having to handle signed integers "less than", +/// "div" and "mod" operations (for simplicity), while also allowing further optimizations to +/// be done during subsequent SSA passes on the expanded instructions. +use acvm::FieldElement; + +use crate::ssa::{ + ir::{ + function::Function, + instruction::{Binary, BinaryOp, ConstrainError, Instruction}, + types::NumericType, + value::ValueId, + }, + ssa_gen::Ssa, +}; + +use super::simple_optimization::SimpleOptimizationContext; + +impl Ssa { + /// Expands signed "less than", "div" and "mod" operations in ACIR to be done using + /// unsigned operations. + /// + /// See [`expand_signed_math`][self] module for more information. + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn expand_signed_math(mut self) -> Ssa { + for function in self.functions.values_mut() { + function.expand_signed_math(); + } + self + } +} + +impl Function { + /// The structure of this pass is simple: + /// Go through each block and re-insert all instructions, decomposing any signed + /// "less than", "div" and "mod" operations to be done using unsigned types, but only if this + /// is an ACIR function. + fn expand_signed_math(&mut self) { + if !self.dfg.runtime().is_acir() { + return; + } + + self.simple_optimization(|context| { + let instruction_id = context.instruction_id; + let instruction = context.instruction(); + + // We only care about "less than" + let Instruction::Binary(Binary { + lhs, + rhs, + operator: operator @ (BinaryOp::Lt | BinaryOp::Div | BinaryOp::Mod), + }) = instruction + else { + return; + }; + + // ... and it must be a signed integer operation. + if !context.dfg.type_of_value(*lhs).is_signed() { + return; + } + + let lhs = *lhs; + let rhs = *rhs; + let operator = *operator; + + // We remove the current instruction, as we will need to replace it with multiple new instructions. + context.remove_current_instruction(); + + let [old_result] = context.dfg.instruction_result(instruction_id); + + let mut expansion_context = Context { context }; + let new_result = match operator { + BinaryOp::Lt => expansion_context.insert_lt(lhs, rhs), + BinaryOp::Div => expansion_context.insert_div(lhs, rhs), + BinaryOp::Mod => expansion_context.insert_mod(lhs, rhs), + _ => unreachable!("ICE: expand_signed_math called on non-lt/div/mod"), + }; + + context.replace_value(old_result, new_result); + }); + + #[cfg(debug_assertions)] + expand_signed_math_post_check(self); + } +} + +struct Context<'m, 'dfg, 'mapping> { + context: &'m mut SimpleOptimizationContext<'dfg, 'mapping>, +} + +impl Context<'_, '_, '_> { + fn insert_lt(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + // First cast lhs and rhs to their unsigned equivalents + let bit_size = self.context.dfg.type_of_value(lhs).bit_size(); + let unsigned_typ = NumericType::unsigned(bit_size); + let lhs_unsigned = self.insert_cast(lhs, unsigned_typ); + let rhs_unsigned = self.insert_cast(rhs, unsigned_typ); + + // Check if lhs and rhs are positive or negative, respectively. + // Values greater than or equal to 2^(bit_size-1) are negative so dividing by that would + // give 0 (positive) or 1 (negative). + let first_negative_value = self.numeric_constant(1_u128 << (bit_size - 1), unsigned_typ); + let lhs_is_negative = self.insert_binary(lhs_unsigned, BinaryOp::Div, first_negative_value); + let lhs_is_negative = self.insert_cast(lhs_is_negative, NumericType::bool()); + let rhs_is_negative = self.insert_binary(rhs_unsigned, BinaryOp::Div, first_negative_value); + let rhs_is_negative = self.insert_cast(rhs_is_negative, NumericType::bool()); + + // Do rhs and lhs have a different sign? + let different_sign = self.insert_binary(lhs_is_negative, BinaryOp::Xor, rhs_is_negative); + + // Check lhs < rhs using their unsigned equivalents + let unsigned_lt = self.insert_binary(lhs_unsigned, BinaryOp::Lt, rhs_unsigned); + + // It can be shown that the result is given by xor'ing the two results above: + // - if lhs and rhs have the same sign (different_sign is 0): + // - if both are positive then the unsigned comparison is correct, xor'ing it with 0 gives + // the same result + // - if both are negative then the unsigned comparison is also correct, as, for example, + // for i8, -128 i8 is Field 128 and -1 i8 is Field 255 and `-128 < -1` and `128 < 255` + // - if lhs and rhs have different signs (different_sign is 1): + // - if lhs is positive and rhs is negative then, as fields, rhs will be greater, but + // the result is the opposite (so xor'ing with 1 gives the correct result) + // - if lhs is negative and rhs is positive then, as fields, lhs will be greater, but + // the result is the opposite (so xor'ing with 1 gives the correct result) + self.insert_binary(different_sign, BinaryOp::Xor, unsigned_lt) + } + + fn insert_div(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let is_division = true; + self.insert_div_or_mod(lhs, rhs, is_division) + } + + fn insert_mod(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let is_division = false; + self.insert_div_or_mod(lhs, rhs, is_division) + } + + fn insert_div_or_mod(&mut self, lhs: ValueId, rhs: ValueId, is_division: bool) -> ValueId { + // First cast lhs and rhs to their unsigned equivalents + let bit_size = self.context.dfg.type_of_value(lhs).bit_size(); + let unsigned_typ = NumericType::unsigned(bit_size); + let lhs_unsigned = self.insert_cast(lhs, unsigned_typ); + let rhs_unsigned = self.insert_cast(rhs, unsigned_typ); + + // There's one condition that could generate an overflow: dividing the minimum + // negative value by -1. For example dividing -128 i8 by -1 would give 128, but that + // does not fit i8. So the first thing we do is check for this case. + let min_negative_value = self.numeric_constant(1_u128 << (bit_size - 1), unsigned_typ); + let minus_one = self.numeric_constant((1_u128 << bit_size) - 1, unsigned_typ); + let lhs_is_min_negative_value = + self.insert_binary(lhs_unsigned, BinaryOp::Eq, min_negative_value); + let rhs_is_minus_one = self.insert_binary(rhs_unsigned, BinaryOp::Eq, minus_one); + let min_overflow = self.insert_binary( + lhs_is_min_negative_value, + BinaryOp::Mul { unchecked: true }, + rhs_is_minus_one, + ); + + let zero = self.numeric_constant(0_u128, NumericType::bool()); + let message = if is_division { + "Attempt to divide with overflow".to_string() + } else { + "Attempt to calculate the remainder with overflow".to_string() + }; + self.insert_constrain(min_overflow, zero, Some(message.into())); + + // What about checking that the divisor is not zero? We don't need to explicitly check + // this here because it'll be checked when doing the unsigned div/mod. + + // Check if lhs and rhs are positive or negative, respectively. + // Values greater than or equal to 2^(bit_size-1) are negative so dividing by that would + // give 0 (positive) or 1 (negative). + let lhs_is_negative = self.insert_binary(lhs_unsigned, BinaryOp::Div, min_negative_value); + let rhs_is_negative = self.insert_binary(rhs_unsigned, BinaryOp::Div, min_negative_value); + + // Here we compute the absolute values of lhs and rhs using their 2-complement + let lhs_absolute = + self.two_complement(lhs_unsigned, lhs_is_negative, unsigned_typ, bit_size); + let rhs_absolute = + self.two_complement(rhs_unsigned, rhs_is_negative, unsigned_typ, bit_size); + + // We then perform the division (or modulo) using the absolute values + let operator = if is_division { BinaryOp::Div } else { BinaryOp::Mod }; + let absolute_result = self.insert_binary(lhs_absolute, operator, rhs_absolute); + + let lhs_is_negative = self.insert_cast(lhs_is_negative, NumericType::bool()); + + // The result changes slightly depending on whether we are doing division or modulo. + let result_is_negative = if is_division { + // For division, the result is negative if lhs and rhs have different signs. + let rhs_is_negative = self.insert_cast(rhs_is_negative, NumericType::bool()); + self.insert_binary(lhs_is_negative, BinaryOp::Xor, rhs_is_negative) + } else { + // For modulo, the result has the same sign as lhs + lhs_is_negative + }; + let result_is_negative = self.insert_cast(result_is_negative, unsigned_typ); + + // We return the 2-complement again if lhs and rhs have different signs, with the + // intention of making the result be negative. + let result_unsigned = + self.two_complement(absolute_result, result_is_negative, unsigned_typ, bit_size); + + // If we divide, for example 4 i8 by -5, the absolute division will give 0. + // Because the signs are different, if we do the two complement of 0 we'll get 256, which + // is out of range. Here we take this case into account: if absolute_div is zero the result + // should be zero, otherwise it should be that result. + // Then, we need to multiply result_unsigned by `absolute_div != 0`. + // + // The same is true for modulo: -4 i8 mod 4 is 0, but taking its two-complement would give 256. + let zero = self.numeric_constant(0_u128, unsigned_typ); + let absolute_result_is_zero = self.insert_binary(absolute_result, BinaryOp::Eq, zero); + let absolute_result_is_not_zero = self.insert_not(absolute_result_is_zero); + let absolute_result_is_not_zero = + self.insert_cast(absolute_result_is_not_zero, unsigned_typ); + + let result_unsigned = self.insert_binary( + result_unsigned, + BinaryOp::Mul { unchecked: true }, + absolute_result_is_not_zero, + ); + + // Make sure we return the signed type + self.insert_cast(result_unsigned, NumericType::signed(bit_size)) + } + + /// Returns the 2-complement of `value`, given `value_is_negative` is 1 if the value is negative, + /// and 0 if it's positive. + /// + /// The math here is: + /// + /// result = value + 2*((2^(bit_size - 1) - value)*value_is_negative) + /// + /// For example, for i8 we have bit_size = 8 so: + /// + /// result = value + 2*(128 - value)*value_is_negative + /// + /// If the value is positive, so value_is_negative = 0: + /// + /// result = value + /// + /// That is, the value stays the same. + /// + /// If value_is_negative = 1 we get: + /// + /// result = value + 2*(128 - value) = value + 256 - 2*value = 256 - value + /// + /// which effectively negates the value in 2-complement representation. + fn two_complement( + &mut self, + value: ValueId, + value_is_negative: ValueId, + unsigned_type: NumericType, + bit_size: u32, + ) -> ValueId { + let max_power_of_two = self.numeric_constant(1_u128 << (bit_size - 1), unsigned_type); + + let intermediate = + self.insert_binary(max_power_of_two, BinaryOp::Sub { unchecked: true }, value); + let intermediate = + self.insert_binary(intermediate, BinaryOp::Mul { unchecked: true }, value_is_negative); + let two = self.numeric_constant(2_u128, unsigned_type); + let intermediate = self.insert_binary(intermediate, BinaryOp::Mul { unchecked: true }, two); + self.insert_binary(value, BinaryOp::Add { unchecked: true }, intermediate) + } + + /// Insert a numeric constant into the current function + fn numeric_constant(&mut self, value: impl Into, typ: NumericType) -> ValueId { + self.context.dfg.make_constant(value.into(), typ) + } + + /// Insert a not instruction at the end of the current block. + /// Returns the result of the instruction. + fn insert_not(&mut self, rhs: ValueId) -> ValueId { + self.context.insert_instruction(Instruction::Not(rhs), None).first() + } + + /// Insert a binary instruction at the end of the current block. + /// Returns the result of the binary instruction. + fn insert_binary(&mut self, lhs: ValueId, operator: BinaryOp, rhs: ValueId) -> ValueId { + let instruction = Instruction::Binary(Binary { lhs, rhs, operator }); + self.context.insert_instruction(instruction, None).first() + } + + /// Insert a cast instruction at the end of the current block. + /// Returns the result of the cast instruction. + fn insert_cast(&mut self, value: ValueId, typ: NumericType) -> ValueId { + self.context.insert_instruction(Instruction::Cast(value, typ), None).first() + } + + /// Insert a constrain instruction at the end of the current block. + fn insert_constrain( + &mut self, + lhs: ValueId, + rhs: ValueId, + assert_message: Option, + ) { + self.context.insert_instruction(Instruction::Constrain(lhs, rhs, assert_message), None); + } +} + +/// Post-check condition for [Function::expand_signed_math]. +/// +/// Succeeds if: +/// - `func` does not contain any signed "less than" ops +/// +/// Otherwise panics. +#[cfg(debug_assertions)] +fn expand_signed_math_post_check(func: &Function) { + for block_id in func.reachable_blocks() { + let instruction_ids = func.dfg[block_id].instructions(); + for instruction_id in instruction_ids { + if let Instruction::Binary(binary) = &func.dfg[*instruction_id] { + if func.dfg.type_of_value(binary.lhs).is_signed() { + match binary.operator { + BinaryOp::Lt => { + panic!("Checked signed 'less than' has not been removed") + } + _ => (), + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + assert_ssa_snapshot, + ssa::{ + interpreter::value::{NumericValue, Value}, + opt::assert_ssa_does_not_change, + ssa_gen::Ssa, + }, + }; + + #[test] + fn expands_signed_lt_in_acir() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = lt v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + + // Check that the expanded code works as expected + let test_cases = [ + (10, 20, true), + (20, 10, false), + (-20, -10, true), + (-10, -20, false), + (-20, 10, true), + (-10, 20, true), + (10, -20, false), + (20, -10, false), + ]; + for (lhs, rhs, expected) in test_cases { + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(lhs)), + Value::Numeric(NumericValue::I8(rhs)), + ]); + assert!(result.is_ok()); + let result = result.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], Value::Numeric(NumericValue::U1(expected))); + } + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + v5 = div v2, u8 128 + v6 = cast v5 as u1 + v7 = div v3, u8 128 + v8 = cast v7 as u1 + v9 = xor v6, v8 + v10 = lt v2, v3 + v11 = xor v9, v10 + return v11 + } + "); + } + + #[test] + fn does_not_expand_signed_lt_in_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = lt v0, v1 + return v2 + } + "; + assert_ssa_does_not_change(src, Ssa::expand_signed_math); + } + + #[test] + fn expands_signed_div_in_acir() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = div v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + + // Check that -128 i8 / -1 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(-128)), + Value::Numeric(NumericValue::I8(-1)), + ]); + assert!(result.is_err()); + + // Check that 10 i8 / 0 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(10)), + Value::Numeric(NumericValue::I8(0)), + ]); + assert!(result.is_err()); + + assert_ssa_snapshot!(ssa, @r#" + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + v5 = eq v2, u8 128 + v7 = eq v3, u8 255 + v8 = unchecked_mul v5, v7 + constrain v8 == u1 0, "Attempt to divide with overflow" + v10 = div v2, u8 128 + v11 = div v3, u8 128 + v12 = unchecked_sub u8 128, v2 + v13 = unchecked_mul v12, v10 + v15 = unchecked_mul v13, u8 2 + v16 = unchecked_add v2, v15 + v17 = unchecked_sub u8 128, v3 + v18 = unchecked_mul v17, v11 + v19 = unchecked_mul v18, u8 2 + v20 = unchecked_add v3, v19 + v21 = div v16, v20 + v22 = cast v10 as u1 + v23 = cast v11 as u1 + v24 = xor v22, v23 + v25 = cast v24 as u8 + v26 = unchecked_sub u8 128, v21 + v27 = unchecked_mul v26, v25 + v28 = unchecked_mul v27, u8 2 + v29 = unchecked_add v21, v28 + v31 = eq v21, u8 0 + v32 = not v31 + v33 = cast v32 as u8 + v34 = unchecked_mul v29, v33 + v35 = cast v34 as i8 + return v35 + } + "#); + } + + #[test] + fn does_not_expands_signed_div_in_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = div v0, v1 + return v2 + } + "; + assert_ssa_does_not_change(src, Ssa::expand_signed_math); + } + + #[test] + fn expands_signed_mod_in_acir() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = mod v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + + // Check that -128 i8 / -1 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(-128)), + Value::Numeric(NumericValue::I8(-1)), + ]); + assert!(result.is_err()); + + // Check that 10 i8 / 0 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(10)), + Value::Numeric(NumericValue::I8(0)), + ]); + assert!(result.is_err()); + + assert_ssa_snapshot!(ssa, @r#" + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + v5 = eq v2, u8 128 + v7 = eq v3, u8 255 + v8 = unchecked_mul v5, v7 + constrain v8 == u1 0, "Attempt to calculate the remainder with overflow" + v10 = div v2, u8 128 + v11 = div v3, u8 128 + v12 = unchecked_sub u8 128, v2 + v13 = unchecked_mul v12, v10 + v15 = unchecked_mul v13, u8 2 + v16 = unchecked_add v2, v15 + v17 = unchecked_sub u8 128, v3 + v18 = unchecked_mul v17, v11 + v19 = unchecked_mul v18, u8 2 + v20 = unchecked_add v3, v19 + v21 = mod v16, v20 + v22 = cast v10 as u1 + v23 = cast v10 as u8 + v24 = unchecked_sub u8 128, v21 + v25 = unchecked_mul v24, v23 + v26 = unchecked_mul v25, u8 2 + v27 = unchecked_add v21, v26 + v29 = eq v21, u8 0 + v30 = not v29 + v31 = cast v30 as u8 + v32 = unchecked_mul v27, v31 + v33 = cast v32 as i8 + return v33 + } + "#); + } + + #[test] + fn does_not_expands_signed_mod_in_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = mod v0, v1 + return v2 + } + "; + assert_ssa_does_not_change(src, Ssa::expand_signed_math); + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 21cf15d4876..9946e6f3d91 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -16,6 +16,7 @@ mod defunctionalize; mod die; mod evaluate_static_assert_and_assert_constant; mod expand_signed_checks; +mod expand_signed_math; pub(crate) mod flatten_cfg; mod hint; mod inline_simple_functions; diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index df1851b76f7..a94afeafa12 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -252,10 +252,15 @@ impl Context<'_, '_, '_> { } NumericType::Signed { bit_size } => { // Get the sign of the operand; positive signed operand will just do a division as well - let zero = - self.numeric_constant(FieldElement::zero(), NumericType::signed(bit_size)); + let unsigned_typ = NumericType::unsigned(bit_size); + let lhs_as_unsigned = self.insert_cast(lhs, unsigned_typ); + // The sign will be 0 for positive numbers and 1 for negatives, so it covers both cases. - let lhs_sign = self.insert_binary(lhs, BinaryOp::Lt, zero); + // To compute this we check if the value, as a Field, is greater or equal than the maximum + // value that is considered positive, that is, 2^(bit_size-1)-1: 2^(bit_size-1)-1 < lhs_as_field + let max_positive = (1_u128 << (bit_size - 1)) - 1; + let max_positive = self.numeric_constant(max_positive, unsigned_typ); + let lhs_sign = self.insert_binary(max_positive, BinaryOp::Lt, lhs_as_unsigned); let lhs_sign_as_field = self.insert_cast(lhs_sign, NumericType::NativeField); let lhs_as_field = self.insert_cast(lhs, NumericType::NativeField); // For negative numbers, we prepare for the division using a wrapping addition of a + 1. Unchecked add as these are fields. @@ -914,17 +919,18 @@ mod tests { assert_ssa_snapshot!(ssa, @r" acir(inline) fn main f0 { b0(v0: i32): - v2 = lt v0, i32 0 - v3 = cast v2 as Field - v4 = cast v0 as Field - v5 = add v3, v4 - v6 = truncate v5 to 32 bits, max_bit_size: 33 - v7 = cast v6 as i32 - v9 = div v7, i32 4 - v10 = cast v2 as i32 - v11 = unchecked_sub v9, v10 - v12 = truncate v11 to 32 bits, max_bit_size: 33 - return v12 + v1 = cast v0 as u32 + v3 = lt u32 2147483647, v1 + v4 = cast v3 as Field + v5 = cast v0 as Field + v6 = add v4, v5 + v7 = truncate v6 to 32 bits, max_bit_size: 33 + v8 = cast v7 as i32 + v10 = div v8, i32 4 + v11 = cast v3 as i32 + v12 = unchecked_sub v10, v11 + v13 = truncate v12 to 32 bits, max_bit_size: 33 + return v13 } "); } @@ -992,17 +998,18 @@ mod tests { v55 = mul v54, v50 v56 = add v53, v55 v57 = cast v56 as i32 - v59 = lt v0, i32 0 - v60 = cast v59 as Field - v61 = cast v0 as Field - v62 = add v60, v61 - v63 = truncate v62 to 32 bits, max_bit_size: 33 - v64 = cast v63 as i32 - v65 = div v64, v57 - v66 = cast v59 as i32 - v67 = unchecked_sub v65, v66 - v68 = truncate v67 to 32 bits, max_bit_size: 33 - return v68 + v58 = cast v0 as u32 + v60 = lt u32 2147483647, v58 + v61 = cast v60 as Field + v62 = cast v0 as Field + v63 = add v61, v62 + v64 = truncate v63 to 32 bits, max_bit_size: 33 + v65 = cast v64 as i32 + v66 = div v65, v57 + v67 = cast v60 as i32 + v68 = unchecked_sub v66, v67 + v69 = truncate v68 to 32 bits, max_bit_size: 33 + return v69 } "#); }