diff --git a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs index 222d3da2935..06333fa0174 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs @@ -597,10 +597,8 @@ impl AcirContext { self.euclidean_division_var(lhs, rhs, bit_size, predicate)?; Ok(quotient_var) } - NumericType::Signed { bit_size } => { - let (quotient_var, _remainder_var) = - self.signed_division_var(lhs, rhs, bit_size, predicate)?; - Ok(quotient_var) + NumericType::Signed { .. } => { + unreachable!("Signed division should have been removed before ACIRgen") } } } @@ -1083,96 +1081,6 @@ impl AcirContext { Ok(()) } - // Returns the 2-complement of lhs, using the provided sign bit in 'leading' - // if leading is zero, it returns lhs - // if leading is one, it returns 2^bit_size-lhs - fn two_complement( - &mut self, - lhs: AcirVar, - leading: AcirVar, - max_bit_size: u32, - ) -> Result { - let max_power_of_two = self.add_constant(power_of_two::(max_bit_size - 1)); - - let intermediate = self.sub_var(max_power_of_two, lhs)?; - let intermediate = self.mul_var(intermediate, leading)?; - - self.add_mul_var(lhs, F::from(2_u128), intermediate) - } - - /// Returns the quotient and remainder such that lhs = rhs * quotient + remainder - /// and |remainder| < |rhs| - /// and remainder has the same sign than lhs - /// Note that this is not the euclidean division, where we have instead remainder < |rhs| - fn signed_division_var( - &mut self, - lhs: AcirVar, - rhs: AcirVar, - bit_size: u32, - predicate: AcirVar, - ) -> Result<(AcirVar, AcirVar), RuntimeError> { - // We derive the signed division from the unsigned euclidean division. - // note that this is not euclidean division! - // If `x` is a signed integer, then `sign(x)x >= 0` - // so if `a` and `b` are signed integers, we can do the unsigned division: - // `sign(a)a = q1*sign(b)b + r1` - // => `a = sign(a)sign(b)q1*b + sign(a)r1` - // => `a = qb+r`, with `|r|<|b|` and `a` and `r` have the same sign. - - assert_ne!(bit_size, 0, "signed integer should have at least one bit"); - - // 2^{max_bit size-1} - let max_power_of_two = self.add_constant(power_of_two::(bit_size - 1)); - let zero = self.add_constant(F::zero()); - let one = self.add_constant(F::one()); - - // Get the sign bit of rhs by computing rhs / max_power_of_two - let (rhs_leading, _) = self.euclidean_division_var(rhs, max_power_of_two, bit_size, one)?; - - // Get the sign bit of lhs by computing lhs / max_power_of_two - let (lhs_leading, _) = self.euclidean_division_var(lhs, max_power_of_two, bit_size, one)?; - - // Signed to unsigned: - let unsigned_lhs = self.two_complement(lhs, lhs_leading, bit_size)?; - let unsigned_rhs = self.two_complement(rhs, rhs_leading, bit_size)?; - - // Performs the division using the unsigned values of lhs and rhs - let (q1, r1) = - self.euclidean_division_var(unsigned_lhs, unsigned_rhs, bit_size, predicate)?; - - // Unsigned to signed: derive q and r from q1,r1 and the signs of lhs and rhs - // Quotient sign is lhs sign * rhs sign, whose resulting sign bit is the XOR of the sign bits - let q_sign = self.xor_var(lhs_leading, rhs_leading, AcirType::unsigned(1))?; - let quotient = self.two_complement(q1, q_sign, bit_size)?; - let remainder = self.two_complement(r1, lhs_leading, bit_size)?; - - // Issue #5129 - When q1 is zero and quotient sign is -1, we compute -0=2^{bit_size}, - // which is not valid because we do not wrap integer operations - // Similar case can happen with the remainder. - let q_is_0 = self.eq_var(q1, zero)?; - let q_is_not_0 = self.not_var(q_is_0, AcirType::unsigned(1))?; - let quotient = self.mul_var(quotient, q_is_not_0)?; - let r_is_0 = self.eq_var(r1, zero)?; - let r_is_not_0 = self.not_var(r_is_0, AcirType::unsigned(1))?; - let remainder = self.mul_var(remainder, r_is_not_0)?; - - // The quotient must be a valid signed integer. - // For instance -128/-1 = 128, but 128 is not a valid i8 - // Because it is the only possible overflow that can happen due to signed representation, - // we simply check for this case: quotient is negative, or distinct from 2^{bit_size-1} - // Indeed, negative quotient cannot 'overflow' because the division will not increase its absolute value - let assert_message = - self.generate_assertion_message_payload("Attempt to divide with overflow".to_string()); - let unsigned = self.not_var(q_sign, AcirType::unsigned(1))?; - - // This overflow check must also be under the predicate - let unsigned = self.mul_var(unsigned, predicate)?; - - self.assert_neq_var(quotient, max_power_of_two, unsigned, Some(assert_message))?; - - Ok((quotient, remainder)) - } - /// Returns a variable which is constrained to be `lhs mod rhs` pub(crate) fn modulo_var( &mut self, @@ -1190,8 +1098,8 @@ impl AcirContext { }; let (_, remainder_var) = match numeric_type { - NumericType::Signed { bit_size } => { - self.signed_division_var(lhs, rhs, bit_size, predicate)? + NumericType::Signed { .. } => { + unreachable!("Signed modulo should have been removed before ACIRgen") } _ => self.euclidean_division_var(lhs, rhs, bit_size, predicate)?, }; diff --git a/compiler/noirc_evaluator/src/acir/tests/mod.rs b/compiler/noirc_evaluator/src/acir/tests/mod.rs index 7d639e4f2ab..cf5390b3276 100644 --- a/compiler/noirc_evaluator/src/acir/tests/mod.rs +++ b/compiler/noirc_evaluator/src/acir/tests/mod.rs @@ -243,33 +243,6 @@ fn derive_pedersen_generators_requires_constant_input() { .expect_err("Should fail with assert constant"); } -#[test] -// Regression for https://github.com/noir-lang/noir/issues/9847 -fn signed_div_overflow() { - // Test that check -128 / -1 overflow for i8 - let src = r#" - acir(inline) predicate_pure fn main f0 { - b0(v1: i8, v2: i8): - v3 = div v1, v2 - return - } - "#; - - let ssa = Ssa::from_str(src).unwrap(); - let inputs = vec![FieldElement::from(128_u128), FieldElement::from(255_u128)]; - let inputs = inputs - .into_iter() - .enumerate() - .map(|(i, f)| (Witness(i as u32), f)) - .collect::>(); - let initial_witness = WitnessMap::from(inputs); - let output = None; - - // acir execution should fail to divide -128 / -1 - let acir_execution_result = execute_ssa(ssa, initial_witness.clone(), output.as_ref()); - assert!(matches!(acir_execution_result, (ACVMStatus::Failure(_), _))); -} - /// Convert the SSA input into ACIR and use ACVM to execute it /// Returns the ACVM execution status and the value of the 'output' witness value, /// unless the provided output is None or the ACVM fails during execution. diff --git a/compiler/noirc_evaluator/src/ssa/mod.rs b/compiler/noirc_evaluator/src/ssa/mod.rs index 9e7f75bb8ff..2058d97bc1d 100644 --- a/compiler/noirc_evaluator/src/ssa/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/mod.rs @@ -125,7 +125,6 @@ pub struct ArtifactsAndWarnings(pub Artifacts, pub Vec); pub fn primary_passes(options: &SsaEvaluatorOptions) -> Vec> { vec![ SsaPass::new(Ssa::expand_signed_checks, "expand signed checks"), - SsaPass::new(Ssa::expand_signed_math, "expand signed math"), SsaPass::new(Ssa::remove_unreachable_functions, "Removing Unreachable Functions"), SsaPass::new(Ssa::defunctionalize, "Defunctionalization"), SsaPass::new_try(Ssa::inline_simple_functions, "Inlining simple functions") @@ -175,6 +174,9 @@ pub fn primary_passes(options: &SsaEvaluatorOptions) -> Vec> { SsaPass::new(Ssa::simplify_cfg, "Simplifying"), SsaPass::new(Ssa::mem2reg, "Mem2Reg"), SsaPass::new(Ssa::remove_bit_shifts, "Removing Bit Shifts"), + // Expand signed lt/div/mod after "Removing Bit Shifts" because that pass might + // introduce signed divisions. + SsaPass::new(Ssa::expand_signed_math, "Expand signed math"), SsaPass::new(Ssa::simplify_cfg, "Simplifying"), SsaPass::new(Ssa::flatten_cfg, "Flattening"), // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores diff --git a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs index 6820fa6c357..c02339b9f72 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs @@ -3,7 +3,7 @@ use acvm::FieldElement; use crate::ssa::{ ir::{ function::Function, - instruction::{Binary, BinaryOp, Instruction}, + instruction::{Binary, BinaryOp, ConstrainError, Instruction}, types::NumericType, value::ValueId, }, @@ -25,8 +25,9 @@ impl Ssa { impl Function { /// The structure of this pass is simple: - /// Go through each block and re-insert all instructions, decomposing any checked signed "less than" operations - /// to be done using unsigned types, but only if this is an ACIR function. + /// Go through each block and re-insert all instructions, decomposing any signed + /// "less than", "div" and "mod" operations to be done using unsigned types, but only if this + /// is an ACIR function. fn expand_signed_math(&mut self) { if !self.dfg.runtime().is_acir() { return; @@ -37,7 +38,11 @@ impl Function { let instruction = context.instruction(); // We only care about "less than" - let Instruction::Binary(Binary { lhs, rhs, operator: BinaryOp::Lt }) = instruction + let Instruction::Binary(Binary { + lhs, + rhs, + operator: operator @ (BinaryOp::Lt | BinaryOp::Div | BinaryOp::Mod), + }) = instruction else { return; }; @@ -49,6 +54,7 @@ impl Function { let lhs = *lhs; let rhs = *rhs; + let operator = *operator; // We remove the current instruction, as we will need to replace it with multiple new instructions. context.remove_current_instruction(); @@ -56,7 +62,12 @@ impl Function { let [old_result] = context.dfg.instruction_result(instruction_id); let mut expansion_context = Context { context }; - let new_result = expansion_context.insert_lt(lhs, rhs); + let new_result = match operator { + BinaryOp::Lt => expansion_context.insert_lt(lhs, rhs), + BinaryOp::Div => expansion_context.insert_div(lhs, rhs), + BinaryOp::Mod => expansion_context.insert_mod(lhs, rhs), + _ => unreachable!("ICE: expand_signed_math called on non-lt/div/mod"), + }; context.replace_value(old_result, new_result); }); @@ -107,11 +118,156 @@ impl Context<'_, '_, '_> { self.insert_binary(different_sign, BinaryOp::Xor, unsigned_lt) } + fn insert_div(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let is_division = true; + self.insert_div_or_mod(lhs, rhs, is_division) + } + + fn insert_mod(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + let is_division = false; + self.insert_div_or_mod(lhs, rhs, is_division) + } + + fn insert_div_or_mod(&mut self, lhs: ValueId, rhs: ValueId, is_division: bool) -> ValueId { + // First cast lhs and rhs to their unsigned equivalents + let bit_size = self.context.dfg.type_of_value(lhs).bit_size(); + let unsigned_typ = NumericType::unsigned(bit_size); + let lhs_unsigned = self.insert_cast(lhs, unsigned_typ); + let rhs_unsigned = self.insert_cast(rhs, unsigned_typ); + + // There's one condition that could generate an overflow: dividing the minimum + // negative value by -1. For example dividing -128 i8 by -1 would give 128, but that + // does not fit i8. So the first thing we do is check for this case. + let min_negative_value = self.numeric_constant(1_u128 << (bit_size - 1), unsigned_typ); + let minus_one = self.numeric_constant((1_u128 << bit_size) - 1, unsigned_typ); + let lhs_is_min_negative_value = + self.insert_binary(lhs_unsigned, BinaryOp::Eq, min_negative_value); + let rhs_is_minus_one = self.insert_binary(rhs_unsigned, BinaryOp::Eq, minus_one); + let min_overflow = self.insert_binary( + lhs_is_min_negative_value, + BinaryOp::Mul { unchecked: true }, + rhs_is_minus_one, + ); + + let zero = self.numeric_constant(0_u128, NumericType::bool()); + let message = if is_division { + "Attempt to divide with overflow".to_string() + } else { + "Attempt to calculate the remainder with overflow".to_string() + }; + self.insert_constrain(min_overflow, zero, Some(message.into())); + + // What about checking that the divisor is not zero? We don't need to explicitly check + // this here because it'll be checked when doing the unsigned div/mod. + + // Check if lhs and rhs are positive or negative, respectively. + // Values greater than or equal to 2^(bit_size-1) are negative so dividing by that would + // give 0 (positive) or 1 (negative). + let lhs_is_negative = self.insert_binary(lhs_unsigned, BinaryOp::Div, min_negative_value); + let rhs_is_negative = self.insert_binary(rhs_unsigned, BinaryOp::Div, min_negative_value); + + // Here we compute the absolute values of lhs and rhs using their 2-complement + let lhs_absolute = + self.two_complement(lhs_unsigned, lhs_is_negative, unsigned_typ, bit_size); + let rhs_absolute = + self.two_complement(rhs_unsigned, rhs_is_negative, unsigned_typ, bit_size); + + // We then perform the division (or modulo) using the absolute values + let operator = if is_division { BinaryOp::Div } else { BinaryOp::Mod }; + let absolute_result = self.insert_binary(lhs_absolute, operator, rhs_absolute); + + let lhs_is_negative = self.insert_cast(lhs_is_negative, NumericType::bool()); + + // The result changes slightly depending on whether we are doing division or modulo. + let result_is_negative = if is_division { + // For division, the result is negative if lhs and rhs have different signs. + let rhs_is_negative = self.insert_cast(rhs_is_negative, NumericType::bool()); + self.insert_binary(lhs_is_negative, BinaryOp::Xor, rhs_is_negative) + } else { + // For modulo, the result has the same sign as lhs + lhs_is_negative + }; + let result_is_negative = self.insert_cast(result_is_negative, unsigned_typ); + + // We return the 2-complement again if lhs and rhs have different signs, with the + // intention of making the result be negative. + let result_unsigned = + self.two_complement(absolute_result, result_is_negative, unsigned_typ, bit_size); + + // If we divide, for example 4 i8 by -5, the absolute division will give 0. + // Because the signs are different, if we do the two complement of 0 we'll get 256, which + // is out of range. Here we take this case into account: if absolute_div is zero the result + // should be zero, otherwise it should be that result. + // Then, we need to multiply result_unsigned by `absolute_div != 0`. + // + // The same is true for modulo: -4 i8 mod 4 is 0, but taking its two-complement would give 256. + let zero = self.numeric_constant(0_u128, unsigned_typ); + let absolute_result_is_zero = self.insert_binary(absolute_result, BinaryOp::Eq, zero); + let absolute_result_is_not_zero = self.insert_not(absolute_result_is_zero); + let absolute_result_is_not_zero = + self.insert_cast(absolute_result_is_not_zero, unsigned_typ); + + let result_unsigned = self.insert_binary( + result_unsigned, + BinaryOp::Mul { unchecked: true }, + absolute_result_is_not_zero, + ); + + // Make sure we return the signed type + self.insert_cast(result_unsigned, NumericType::signed(bit_size)) + } + + /// Returns the 2-complement of `value`, given `value_is_negative` is 1 if the value is negative, + /// and 0 if it's positive. + /// + /// The math here is: + /// + /// result = value + 2*((2^(bit_size - 1) - value)*value_is_negative) + /// + /// For example, for i8 we have bit_size = 8 so: + /// + /// result = value + 2*(128 - value)*value_is_negative + /// + /// If the value is positive, so value_is_negative = 0: + /// + /// result = value + /// + /// That is, the value stays the same. + /// + /// If value_is_negative = 1 we get: + /// + /// result = value + 2*(128 - value) = value + 256 - 2*value = 256 - value + /// + /// which effectively negates the value in 2-complement representation. + fn two_complement( + &mut self, + value: ValueId, + value_is_negative: ValueId, + unsigned_type: NumericType, + bit_size: u32, + ) -> ValueId { + let max_power_of_two = self.numeric_constant(1_u128 << (bit_size - 1), unsigned_type); + + let intermediate = + self.insert_binary(max_power_of_two, BinaryOp::Sub { unchecked: true }, value); + let intermediate = + self.insert_binary(intermediate, BinaryOp::Mul { unchecked: true }, value_is_negative); + let two = self.numeric_constant(2_u128, unsigned_type); + let intermediate = self.insert_binary(intermediate, BinaryOp::Mul { unchecked: true }, two); + self.insert_binary(value, BinaryOp::Add { unchecked: true }, intermediate) + } + /// Insert a numeric constant into the current function fn numeric_constant(&mut self, value: impl Into, typ: NumericType) -> ValueId { self.context.dfg.make_constant(value.into(), typ) } + /// Insert a not instruction at the end of the current block. + /// Returns the result of the instruction. + fn insert_not(&mut self, rhs: ValueId) -> ValueId { + self.context.insert_instruction(Instruction::Not(rhs), None).first() + } + /// Insert a binary instruction at the end of the current block. /// Returns the result of the binary instruction. fn insert_binary(&mut self, lhs: ValueId, operator: BinaryOp, rhs: ValueId) -> ValueId { @@ -124,6 +280,16 @@ impl Context<'_, '_, '_> { fn insert_cast(&mut self, value: ValueId, typ: NumericType) -> ValueId { self.context.insert_instruction(Instruction::Cast(value, typ), None).first() } + + /// Insert a constrain instruction at the end of the current block. + fn insert_constrain( + &mut self, + lhs: ValueId, + rhs: ValueId, + assert_message: Option, + ) { + self.context.insert_instruction(Instruction::Constrain(lhs, rhs, assert_message), None); + } } /// Post-check condition for [Function::expand_signed_math]. @@ -155,7 +321,11 @@ fn expand_signed_math_post_check(func: &Function) { mod tests { use crate::{ assert_ssa_snapshot, - ssa::{opt::assert_ssa_does_not_change, ssa_gen::Ssa}, + ssa::{ + interpreter::value::{NumericValue, Value}, + opt::assert_ssa_does_not_change, + ssa_gen::Ssa, + }, }; #[test] @@ -197,4 +367,151 @@ mod tests { "; assert_ssa_does_not_change(src, Ssa::expand_signed_math); } + + #[test] + fn expands_signed_div_in_acir() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = div v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + assert_ssa_snapshot!(ssa, @r#" + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + v5 = eq v2, u8 128 + v7 = eq v3, u8 255 + v8 = unchecked_mul v5, v7 + constrain v8 == u1 0, "Attempt to divide with overflow" + v10 = div v2, u8 128 + v11 = div v3, u8 128 + v12 = unchecked_sub u8 128, v2 + v13 = unchecked_mul v12, v10 + v15 = unchecked_mul v13, u8 2 + v16 = unchecked_add v2, v15 + v17 = unchecked_sub u8 128, v3 + v18 = unchecked_mul v17, v11 + v19 = unchecked_mul v18, u8 2 + v20 = unchecked_add v3, v19 + v21 = div v16, v20 + v22 = cast v10 as u1 + v23 = cast v11 as u1 + v24 = xor v22, v23 + v25 = cast v24 as u8 + v26 = unchecked_sub u8 128, v21 + v27 = unchecked_mul v26, v25 + v28 = unchecked_mul v27, u8 2 + v29 = unchecked_add v21, v28 + v31 = eq v21, u8 0 + v32 = not v31 + v33 = cast v32 as u8 + v34 = unchecked_mul v29, v33 + v35 = cast v34 as i8 + return v35 + } + "#); + } + + #[test] + fn signed_div_expansion_checks_overflows() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = div v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + + // Check that -128 i8 / -1 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(-128)), + Value::Numeric(NumericValue::I8(-1)), + ]); + assert!(result.is_err()); + + // Check that 10 i8 / 0 i8 overflows + let result = ssa.interpret(vec![ + Value::Numeric(NumericValue::I8(10)), + Value::Numeric(NumericValue::I8(0)), + ]); + assert!(result.is_err()); + } + + #[test] + fn does_not_expands_signed_div_in_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = div v0, v1 + return v2 + } + "; + assert_ssa_does_not_change(src, Ssa::expand_signed_math); + } + + #[test] + fn expands_signed_mod_in_acir() { + let src = " + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = mod v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.expand_signed_math(); + assert_ssa_snapshot!(ssa, @r#" + acir(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + v5 = eq v2, u8 128 + v7 = eq v3, u8 255 + v8 = unchecked_mul v5, v7 + constrain v8 == u1 0, "Attempt to calculate the remainder with overflow" + v10 = div v2, u8 128 + v11 = div v3, u8 128 + v12 = unchecked_sub u8 128, v2 + v13 = unchecked_mul v12, v10 + v15 = unchecked_mul v13, u8 2 + v16 = unchecked_add v2, v15 + v17 = unchecked_sub u8 128, v3 + v18 = unchecked_mul v17, v11 + v19 = unchecked_mul v18, u8 2 + v20 = unchecked_add v3, v19 + v21 = mod v16, v20 + v22 = cast v10 as u1 + v23 = cast v10 as u8 + v24 = unchecked_sub u8 128, v21 + v25 = unchecked_mul v24, v23 + v26 = unchecked_mul v25, u8 2 + v27 = unchecked_add v21, v26 + v29 = eq v21, u8 0 + v30 = not v29 + v31 = cast v30 as u8 + v32 = unchecked_mul v27, v31 + v33 = cast v32 as i8 + return v33 + } + "#); + } + + #[test] + fn does_not_expands_signed_mod_in_brillig() { + let src = " + brillig(inline) fn main f0 { + b0(v0: i8, v1: i8): + v2 = mod v0, v1 + return v2 + } + "; + assert_ssa_does_not_change(src, Ssa::expand_signed_math); + } }