diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index dae78079133..f8a0e47352a 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -552,10 +552,10 @@ impl std::ops::Index for FunctionBuilder { fn validate_numeric_type(typ: &NumericType) { match &typ { NumericType::Signed { bit_size } => match bit_size { - 8 | 16 | 32 | 64 | 128 => (), + 8 | 16 | 32 | 64 => (), _ => { panic!( - "Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, 64 or 128." + "Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, or 64." ); } }, diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs index 11b849cc5fa..ba62c98932b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs @@ -24,7 +24,11 @@ pub(super) fn simplify_cast( if let Value::Instruction { instruction, .. } = &dfg[value] { if let Instruction::Cast(original_value, _) = &dfg[*instruction] { - return SimplifiedToInstruction(Instruction::Cast(*original_value, dst_typ)); + let original_value = *original_value; + return match simplify_cast(original_value, dst_typ, dfg) { + None => SimplifiedToInstruction(Instruction::Cast(original_value, dst_typ)), + simpler => simpler, + }; } } @@ -151,4 +155,31 @@ mod tests { } "); } + + #[test] + fn simplifies_out_casting_there_and_back() { + // Casting from e.g. i8 to u64 used to go through sign extending to i64, + // which itself first cast to u8, then u64 to do some arithmetic, then + // the result was cast to i64 and back to u64. + let src = " + acir(inline) fn main f0 { + b0(v0: u64, v1: u64): + v2 = unchecked_add v0, v1 + v3 = cast v2 as i64 + v4 = cast v3 as u64 + return v4 + } + "; + + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u64, v1: u64): + v2 = unchecked_add v0, v1 + v3 = cast v2 as i64 + return v2 + } + "); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs index 2e6bea12bf9..6e9762f046c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs @@ -505,8 +505,7 @@ mod tests { v12 = eq v11, v8 v13 = unchecked_mul v12, v10 constrain v13 == v10, "attempt to add with overflow" - v14 = cast v3 as i32 - return v14 + return v3 } "#); } @@ -538,8 +537,7 @@ mod tests { v13 = eq v12, v8 v14 = unchecked_mul v13, v11 constrain v14 == v11, "attempt to subtract with overflow" - v15 = cast v3 as i32 - return v15 + return v3 } "#); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs index 587467c505e..1b05ea65e4b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs @@ -519,17 +519,16 @@ mod tests { v20 = unchecked_add v3, v19 v21 = mod v16, v20 v22 = cast v10 as u1 - v23 = cast v10 as u8 - v24 = unchecked_sub u8 128, v21 - v25 = unchecked_mul v24, v23 - v26 = unchecked_mul v25, u8 2 - v27 = unchecked_add v21, v26 - v29 = eq v21, u8 0 - v30 = not v29 - v31 = cast v30 as u8 - v32 = unchecked_mul v27, v31 - v33 = cast v32 as i8 - return v33 + v23 = unchecked_sub u8 128, v21 + v24 = unchecked_mul v23, v10 + v25 = unchecked_mul v24, u8 2 + v26 = unchecked_add v21, v25 + v28 = eq v21, u8 0 + v29 = not v28 + v30 = cast v29 as u8 + v31 = unchecked_mul v26, v30 + v32 = cast v31 as i8 + return v32 } "#); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index a94afeafa12..51200c9b34e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -681,7 +681,7 @@ mod tests { v4 = cast v3 as u64 v6 = lt v4, u64 64 constrain v6 == u1 1, "attempt to bit-shift with overflow" - v8 = cast v3 as Field + v8 = cast v1 as Field v10 = call to_le_bits(v8) -> [u1; 1] v12 = array_get v10, index u32 0 -> u1 v13 = not v12 diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 8f5b8db303c..fde6fac94d8 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -371,7 +371,7 @@ impl<'a> FunctionContext<'a> { /// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`. pub(super) fn insert_safe_cast( &mut self, - mut value: ValueId, + value: ValueId, typ: NumericType, location: Location, ) -> ValueId { @@ -392,68 +392,7 @@ impl<'a> FunctionContext<'a> { } std::cmp::Ordering::Equal => value, std::cmp::Ordering::Greater => { - // If target size is bigger, we do a sign extension: - // When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value - // Sign extension in this case will give `2^t-v`, where `t` is the target bit size - // So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative. - // Casting s-bits signed v0 to t-bits will add the following instructions: - // v1 = cast v0 to 's-bits unsigned' - // v2 = lt v1, 2**(s-1) - // v3 = not(v1) - // v4 = cast v3 to 't-bits unsigned' - // v5 = v3 * (2**t - 2**s) - // v6 = cast v1 to 't-bits unsigned' - // return v6 + v5 - let value_as_unsigned = self.insert_safe_cast( - value, - NumericType::unsigned(*incoming_type_size), - location, - ); - let half_width = self.builder.numeric_constant( - FieldElement::from(2_u128.pow(incoming_type_size - 1)), - NumericType::unsigned(*incoming_type_size), - ); - // value_sign is 1 if the value is positive, 0 otherwise - let value_sign = - self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width); - let max_for_incoming_type_size = if *incoming_type_size == 128 { - u128::MAX - } else { - 2_u128.pow(*incoming_type_size) - 1 - }; - let max_for_target_type_size = if target_type_size == 128 { - u128::MAX - } else { - 2_u128.pow(target_type_size) - 1 - }; - let patch = self.builder.numeric_constant( - FieldElement::from( - max_for_target_type_size - max_for_incoming_type_size, - ), - NumericType::unsigned(target_type_size), - ); - let mut is_negative_predicate = self.builder.insert_not(value_sign); - is_negative_predicate = self.insert_safe_cast( - is_negative_predicate, - NumericType::unsigned(target_type_size), - location, - ); - // multiplication by a boolean cannot overflow - let patch_with_sign_predicate = self.builder.insert_binary( - patch, - BinaryOp::Mul { unchecked: true }, - is_negative_predicate, - ); - let value_as_unsigned = self.builder.insert_cast( - value_as_unsigned, - NumericType::unsigned(target_type_size), - ); - // Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow. - self.builder.insert_binary( - patch_with_sign_predicate, - BinaryOp::Add { unchecked: true }, - value_as_unsigned, - ) + self.sign_extend(value, *incoming_type_size, target_type_size, location) } } } @@ -463,45 +402,55 @@ impl<'a> FunctionContext<'a> { ) => { // If target size is smaller, we do a truncation if target_type_size < *incoming_type_size { - value = - self.builder.insert_truncate(value, target_type_size, *incoming_type_size); + self.builder.insert_truncate(value, target_type_size, *incoming_type_size) + } else { + value } - value } // When casting a signed value to u1 we can truncate then cast ( Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }), NumericType::Unsigned { bit_size: 1 }, ) => self.builder.insert_truncate(value, 1, *incoming_type_size), - // For mixed sign to unsigned or unsigned to sign; - // 1. we cast to the required type using the same signedness - // 2. then we switch the signedness + + // For mixed singed to unsigned: ( Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }), NumericType::Unsigned { bit_size: target_type_size }, ) => { - if *incoming_type_size != target_type_size { - value = self.insert_safe_cast( - value, - NumericType::signed(target_type_size), - location, - ); + // when going from lower to higher bit size: + // 1. we sign-extend to the target bits + // 2. we are already in the target signedness + if *incoming_type_size < target_type_size { + // By not the casting to a signed type with the target bit size, we avoid potentially going + // through i128, which is not a type we support in the frontend, and would be strange in SSA. + self.sign_extend(value, *incoming_type_size, target_type_size, location) + } + // when the target bit size is not higher than the source: + // 1. we cast to the required type using the same signedness + // 2. then we switch the signedness + else if *incoming_type_size != target_type_size { + self.insert_safe_cast(value, NumericType::signed(target_type_size), location) + } else { + value } - value } + + // For mixed unsigned to signed: + // 1. we cast to the required type using the same signedness + // 2. then we switch the signedness ( Type::Numeric(NumericType::Unsigned { bit_size: incoming_type_size }), NumericType::Signed { bit_size: target_type_size }, ) => { if *incoming_type_size != target_type_size { - value = self.insert_safe_cast( - value, - NumericType::unsigned(target_type_size), - location, - ); + self.insert_safe_cast(value, NumericType::unsigned(target_type_size), location) + } else { + value } - value } + + // Field to signed/unsigned: ( Type::Numeric(NumericType::NativeField), NumericType::Unsigned { bit_size: target_type_size }, @@ -517,6 +466,70 @@ impl<'a> FunctionContext<'a> { self.builder.insert_cast(result, typ) } + /// During casting signed values, if target size is bigger, we do a sign extension: + /// + /// When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value. + /// Sign extension in this case will give `2^t-v`, where `t` is the target bit size. + /// So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative. + /// + /// Casting s-bits signed v0 to t-bits will add the following instructions: + /// ```ssa + /// v1 = cast v0 to 's-bits unsigned' + /// v2 = lt v1, 2**(s-1) + /// v3 = not(v1) + /// v4 = cast v3 to 't-bits unsigned' + /// v5 = v3 * (2**t - 2**s) + /// v6 = cast v1 to 't-bits unsigned' + /// return v6 + v5 + /// ``` + /// + /// Return an unsigned value that we can cast back to the signed type if we want, + /// or keep it as it is, if we did the sign extension as part of casting e.g. `i8` to `u64`. + fn sign_extend( + &mut self, + value: ValueId, + incoming_type_size: u32, + target_type_size: u32, + location: Location, + ) -> ValueId { + let value_as_unsigned = + self.insert_safe_cast(value, NumericType::unsigned(incoming_type_size), location); + let half_width = self.builder.numeric_constant( + FieldElement::from(2_u128.pow(incoming_type_size - 1)), + NumericType::unsigned(incoming_type_size), + ); + // value_sign is 1 if the value is positive, 0 otherwise + let value_sign = self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width); + let max_for_incoming_type_size = + if incoming_type_size == 128 { u128::MAX } else { 2_u128.pow(incoming_type_size) - 1 }; + let max_for_target_type_size = + if target_type_size == 128 { u128::MAX } else { 2_u128.pow(target_type_size) - 1 }; + let patch = self.builder.numeric_constant( + FieldElement::from(max_for_target_type_size - max_for_incoming_type_size), + NumericType::unsigned(target_type_size), + ); + let mut is_negative_predicate = self.builder.insert_not(value_sign); + is_negative_predicate = self.insert_safe_cast( + is_negative_predicate, + NumericType::unsigned(target_type_size), + location, + ); + // multiplication by a boolean cannot overflow + let patch_with_sign_predicate = self.builder.insert_binary( + patch, + BinaryOp::Mul { unchecked: true }, + is_negative_predicate, + ); + let value_as_unsigned = + self.builder.insert_cast(value_as_unsigned, NumericType::unsigned(target_type_size)); + // Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow. + self.builder.insert_binary( + patch_with_sign_predicate, + BinaryOp::Add { unchecked: true }, + value_as_unsigned, + ) + } + /// Create a const offset of an address for an array load or store pub(super) fn make_offset( &mut self,