Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,10 @@ impl std::ops::Index<BasicBlockId> for FunctionBuilder {
fn validate_numeric_type(typ: &NumericType) {
match &typ {
NumericType::Signed { bit_size } => match bit_size {
8 | 16 | 32 | 64 | 128 => (),
8 | 16 | 32 | 64 => (),
_ => {
panic!(
"Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, 64 or 128."
"Invalid bit size for signed numeric type: {bit_size}. Expected one of 8, 16, 32, or 64."
);
}
},
Expand Down
33 changes: 32 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ pub(super) fn simplify_cast(

if let Value::Instruction { instruction, .. } = &dfg[value] {
if let Instruction::Cast(original_value, _) = &dfg[*instruction] {
return SimplifiedToInstruction(Instruction::Cast(*original_value, dst_typ));
let original_value = *original_value;
return match simplify_cast(original_value, dst_typ, dfg) {
None => SimplifiedToInstruction(Instruction::Cast(original_value, dst_typ)),
simpler => simpler,
};
}
}

Expand Down Expand Up @@ -151,4 +155,31 @@ mod tests {
}
");
}

#[test]
fn simplifies_out_casting_there_and_back() {
// Casting from e.g. i8 to u64 used to go through sign extending to i64,
// which itself first cast to u8, then u64 to do some arithmetic, then
// the result was cast to i64 and back to u64.
let src = "
acir(inline) fn main f0 {
b0(v0: u64, v1: u64):
v2 = unchecked_add v0, v1
v3 = cast v2 as i64
v4 = cast v3 as u64
return v4
}
";

let ssa = Ssa::from_str_simplifying(src).unwrap();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u64, v1: u64):
v2 = unchecked_add v0, v1
v3 = cast v2 as i64
return v2
}
");
}
}
6 changes: 2 additions & 4 deletions compiler/noirc_evaluator/src/ssa/opt/expand_signed_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,7 @@ mod tests {
v12 = eq v11, v8
v13 = unchecked_mul v12, v10
constrain v13 == v10, "attempt to add with overflow"
v14 = cast v3 as i32
return v14
return v3
}
"#);
}
Expand Down Expand Up @@ -538,8 +537,7 @@ mod tests {
v13 = eq v12, v8
v14 = unchecked_mul v13, v11
constrain v14 == v11, "attempt to subtract with overflow"
v15 = cast v3 as i32
return v15
return v3
}
"#);
}
Expand Down
21 changes: 10 additions & 11 deletions compiler/noirc_evaluator/src/ssa/opt/expand_signed_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,17 +519,16 @@ mod tests {
v20 = unchecked_add v3, v19
v21 = mod v16, v20
v22 = cast v10 as u1
v23 = cast v10 as u8
v24 = unchecked_sub u8 128, v21
v25 = unchecked_mul v24, v23
v26 = unchecked_mul v25, u8 2
v27 = unchecked_add v21, v26
v29 = eq v21, u8 0
v30 = not v29
v31 = cast v30 as u8
v32 = unchecked_mul v27, v31
v33 = cast v32 as i8
return v33
v23 = unchecked_sub u8 128, v21
v24 = unchecked_mul v23, v10
v25 = unchecked_mul v24, u8 2
v26 = unchecked_add v21, v25
v28 = eq v21, u8 0
v29 = not v28
v30 = cast v29 as u8
v31 = unchecked_mul v26, v30
v32 = cast v31 as i8
return v32
}
"#);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ mod tests {
v4 = cast v3 as u64
v6 = lt v4, u64 64
constrain v6 == u1 1, "attempt to bit-shift with overflow"
v8 = cast v3 as Field
v8 = cast v1 as Field
v10 = call to_le_bits(v8) -> [u1; 1]
v12 = array_get v10, index u32 0 -> u1
v13 = not v12
Expand Down
177 changes: 95 additions & 82 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ impl<'a> FunctionContext<'a> {
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
pub(super) fn insert_safe_cast(
&mut self,
mut value: ValueId,
value: ValueId,
typ: NumericType,
location: Location,
) -> ValueId {
Expand All @@ -392,68 +392,7 @@ impl<'a> FunctionContext<'a> {
}
std::cmp::Ordering::Equal => value,
std::cmp::Ordering::Greater => {
// If target size is bigger, we do a sign extension:
// When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value
// Sign extension in this case will give `2^t-v`, where `t` is the target bit size
// So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative.
// Casting s-bits signed v0 to t-bits will add the following instructions:
// v1 = cast v0 to 's-bits unsigned'
// v2 = lt v1, 2**(s-1)
// v3 = not(v1)
// v4 = cast v3 to 't-bits unsigned'
// v5 = v3 * (2**t - 2**s)
// v6 = cast v1 to 't-bits unsigned'
// return v6 + v5
let value_as_unsigned = self.insert_safe_cast(
value,
NumericType::unsigned(*incoming_type_size),
location,
);
let half_width = self.builder.numeric_constant(
FieldElement::from(2_u128.pow(incoming_type_size - 1)),
NumericType::unsigned(*incoming_type_size),
);
// value_sign is 1 if the value is positive, 0 otherwise
let value_sign =
self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width);
let max_for_incoming_type_size = if *incoming_type_size == 128 {
u128::MAX
} else {
2_u128.pow(*incoming_type_size) - 1
};
let max_for_target_type_size = if target_type_size == 128 {
u128::MAX
} else {
2_u128.pow(target_type_size) - 1
};
let patch = self.builder.numeric_constant(
FieldElement::from(
max_for_target_type_size - max_for_incoming_type_size,
),
NumericType::unsigned(target_type_size),
);
let mut is_negative_predicate = self.builder.insert_not(value_sign);
is_negative_predicate = self.insert_safe_cast(
is_negative_predicate,
NumericType::unsigned(target_type_size),
location,
);
// multiplication by a boolean cannot overflow
let patch_with_sign_predicate = self.builder.insert_binary(
patch,
BinaryOp::Mul { unchecked: true },
is_negative_predicate,
);
let value_as_unsigned = self.builder.insert_cast(
value_as_unsigned,
NumericType::unsigned(target_type_size),
);
// Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow.
self.builder.insert_binary(
patch_with_sign_predicate,
BinaryOp::Add { unchecked: true },
value_as_unsigned,
)
self.sign_extend(value, *incoming_type_size, target_type_size, location)
}
}
}
Expand All @@ -463,45 +402,55 @@ impl<'a> FunctionContext<'a> {
) => {
// If target size is smaller, we do a truncation
if target_type_size < *incoming_type_size {
value =
self.builder.insert_truncate(value, target_type_size, *incoming_type_size);
self.builder.insert_truncate(value, target_type_size, *incoming_type_size)
} else {
value
}
value
}
// When casting a signed value to u1 we can truncate then cast
(
Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }),
NumericType::Unsigned { bit_size: 1 },
) => self.builder.insert_truncate(value, 1, *incoming_type_size),
// For mixed sign to unsigned or unsigned to sign;
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness

// For mixed singed to unsigned:
(
Type::Numeric(NumericType::Signed { bit_size: incoming_type_size }),
NumericType::Unsigned { bit_size: target_type_size },
) => {
if *incoming_type_size != target_type_size {
value = self.insert_safe_cast(
value,
NumericType::signed(target_type_size),
location,
);
// when going from lower to higher bit size:
// 1. we sign-extend to the target bits
// 2. we are already in the target signedness
if *incoming_type_size < target_type_size {
// By not the casting to a signed type with the target bit size, we avoid potentially going
// through i128, which is not a type we support in the frontend, and would be strange in SSA.
self.sign_extend(value, *incoming_type_size, target_type_size, location)
}
// when the target bit size is not higher than the source:
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness
else if *incoming_type_size != target_type_size {
self.insert_safe_cast(value, NumericType::signed(target_type_size), location)
} else {
value
}
value
}

// For mixed unsigned to signed:
// 1. we cast to the required type using the same signedness
// 2. then we switch the signedness
(
Type::Numeric(NumericType::Unsigned { bit_size: incoming_type_size }),
NumericType::Signed { bit_size: target_type_size },
) => {
if *incoming_type_size != target_type_size {
value = self.insert_safe_cast(
value,
NumericType::unsigned(target_type_size),
location,
);
self.insert_safe_cast(value, NumericType::unsigned(target_type_size), location)
} else {
value
}
value
}

// Field to signed/unsigned:
(
Type::Numeric(NumericType::NativeField),
NumericType::Unsigned { bit_size: target_type_size },
Expand All @@ -517,6 +466,70 @@ impl<'a> FunctionContext<'a> {
self.builder.insert_cast(result, typ)
}

/// During casting signed values, if target size is bigger, we do a sign extension:
///
/// When the value is negative, it is represented in 2-complement form; `2^s-v`, where `s` is the incoming bit size and `v` is the absolute value.
/// Sign extension in this case will give `2^t-v`, where `t` is the target bit size.
/// So we simply convert `2^s-v` into `2^t-v` by adding `2^t-2^s` to the value when the value is negative.
///
/// Casting s-bits signed v0 to t-bits will add the following instructions:
/// ```ssa
/// v1 = cast v0 to 's-bits unsigned'
/// v2 = lt v1, 2**(s-1)
/// v3 = not(v1)
/// v4 = cast v3 to 't-bits unsigned'
/// v5 = v3 * (2**t - 2**s)
/// v6 = cast v1 to 't-bits unsigned'
/// return v6 + v5
/// ```
///
/// Return an unsigned value that we can cast back to the signed type if we want,
/// or keep it as it is, if we did the sign extension as part of casting e.g. `i8` to `u64`.
fn sign_extend(
&mut self,
value: ValueId,
incoming_type_size: u32,
target_type_size: u32,
location: Location,
) -> ValueId {
let value_as_unsigned =
self.insert_safe_cast(value, NumericType::unsigned(incoming_type_size), location);
let half_width = self.builder.numeric_constant(
FieldElement::from(2_u128.pow(incoming_type_size - 1)),
NumericType::unsigned(incoming_type_size),
);
// value_sign is 1 if the value is positive, 0 otherwise
let value_sign = self.builder.insert_binary(value_as_unsigned, BinaryOp::Lt, half_width);
let max_for_incoming_type_size =
if incoming_type_size == 128 { u128::MAX } else { 2_u128.pow(incoming_type_size) - 1 };
let max_for_target_type_size =
if target_type_size == 128 { u128::MAX } else { 2_u128.pow(target_type_size) - 1 };
let patch = self.builder.numeric_constant(
FieldElement::from(max_for_target_type_size - max_for_incoming_type_size),
NumericType::unsigned(target_type_size),
);
let mut is_negative_predicate = self.builder.insert_not(value_sign);
is_negative_predicate = self.insert_safe_cast(
is_negative_predicate,
NumericType::unsigned(target_type_size),
location,
);
// multiplication by a boolean cannot overflow
let patch_with_sign_predicate = self.builder.insert_binary(
patch,
BinaryOp::Mul { unchecked: true },
is_negative_predicate,
);
let value_as_unsigned =
self.builder.insert_cast(value_as_unsigned, NumericType::unsigned(target_type_size));
// Patch the bit sign, which gives a `target_type_size` bit size value, so it does not overflow.
self.builder.insert_binary(
patch_with_sign_predicate,
BinaryOp::Add { unchecked: true },
value_as_unsigned,
)
}

/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(
&mut self,
Expand Down
Loading