Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,17 +603,12 @@ impl DataFlowGraph {
}
}

/// Returns the Value::Array associated with this ValueId if it refers to an array constant.
/// Returns the item values in with this ValueId if it refers to an array constant, along with the type of the array item.
/// Otherwise, this returns None.
pub(crate) fn get_array_constant(&self, value: ValueId) -> Option<(im::Vector<ValueId>, Type)> {
if let Some(instruction) = self.get_local_or_global_instruction(value) {
match instruction {
Instruction::MakeArray { elements, typ } => Some((elements.clone(), typ.clone())),
_ => None,
}
} else {
// Arrays are shared, so cloning them is cheap
None
match self.get_local_or_global_instruction(value)? {
Instruction::MakeArray { elements, typ } => Some((elements.clone(), typ.clone())),
_ => None,
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ fn simplify_slice_push_back(
slice_sizes.insert(set_last_slice_value, slice_size / element_size);
slice_sizes.insert(new_slice, slice_size / element_size);

let mut value_merger = ValueMerger::new(dfg, block, &mut slice_sizes, call_stack);
let mut value_merger = ValueMerger::new(dfg, block, &slice_sizes, call_stack);

let Ok(new_slice) = value_merger.merge_values(
len_not_equals_capacity,
Expand Down
71 changes: 31 additions & 40 deletions compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/value_merger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ pub(crate) struct ValueMerger<'a> {
dfg: &'a mut DataFlowGraph,
block: BasicBlockId,

// Maps SSA array values with a slice type to their size.
// This must be computed before merging values.
slice_sizes: &'a mut HashMap<ValueId, u32>,
/// Maps SSA array values with a slice type to their size.
/// This must be computed before merging values.
slice_sizes: &'a HashMap<ValueId, u32>,

call_stack: CallStackId,
}
Expand All @@ -27,7 +27,7 @@ impl<'a> ValueMerger<'a> {
pub(crate) fn new(
dfg: &'a mut DataFlowGraph,
block: BasicBlockId,
slice_sizes: &'a mut HashMap<ValueId, u32>,
slice_sizes: &'a HashMap<ValueId, u32>,
call_stack: CallStackId,
) -> Self {
ValueMerger { dfg, block, slice_sizes, call_stack }
Expand Down Expand Up @@ -132,7 +132,7 @@ impl<'a> ValueMerger<'a> {

/// Given an if expression that returns an array: `if c { array1 } else { array2 }`,
/// this function will recursively merge array1 and array2 into a single resulting array
/// by creating a new array containing the result of self.merge_values for each element.
/// by creating a new array containing the result of `self.merge_values` for each element.
pub(crate) fn merge_array_values(
&mut self,
typ: Type,
Expand All @@ -144,14 +144,15 @@ impl<'a> ValueMerger<'a> {
let mut merged = im::Vector::new();

let (element_types, len) = match &typ {
Type::Array(elements, len) => (elements, *len),
Type::Array(elements, len) => (elements.as_slice(), *len),
_ => panic!("Expected array type"),
};

let element_count = element_types.len() as u32;

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index =
u128::from(i * element_types.len() as u32 + element_index as u32).into();
let index = u128::from(i * element_count + element_index as u32).into();
let index = self.dfg.make_constant(index, NumericType::length_type());

let typevars = Some(vec![element_type.clone()]);
Expand Down Expand Up @@ -192,72 +193,62 @@ impl<'a> ValueMerger<'a> {
let mut merged = im::Vector::new();

let element_types = match &typ {
Type::Slice(elements) => elements,
Type::Slice(elements) => elements.as_slice(),
_ => panic!("Expected slice type"),
};

let then_len = self.slice_sizes.get(&then_value_id).copied().unwrap_or_else(|| {
let (slice, typ) = self.dfg.get_array_constant(then_value_id).unwrap_or_else(|| {
panic!("ICE: Merging values during flattening encountered slice {then_value_id} without a preset size");
});
(slice.len() / typ.element_types().len()) as u32
panic!("ICE: Merging values during flattening encountered slice {then_value_id} without a preset size");
});

let else_len = self.slice_sizes.get(&else_value_id).copied().unwrap_or_else(|| {
let (slice, typ) = self.dfg.get_array_constant(else_value_id).unwrap_or_else(|| {
panic!("ICE: Merging values during flattening encountered slice {else_value_id} without a preset size");
});
(slice.len() / typ.element_types().len()) as u32
panic!("ICE: Merging values during flattening encountered slice {else_value_id} without a preset size");
});

let len = then_len.max(else_len);
let element_count = element_types.len() as u32;

let flattened_then_length = then_len * element_types.len() as u32;
let flattened_else_length = else_len * element_types.len() as u32;
let flat_then_length = then_len * element_types.len() as u32;
let flat_else_length = else_len * element_types.len() as u32;

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index_u32 = i * element_types.len() as u32 + element_index as u32;
let index_u32 = i * element_count + element_index as u32;
let index_value = u128::from(index_u32).into();
let index = self.dfg.make_constant(index_value, NumericType::length_type());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars, len| {
if len <= index_u32 {
panic!("get_element invoked with an out of bounds index");
} else {
let get = Instruction::ArrayGet { array, index };
let results = self.dfg.insert_instruction_and_results(
get,
self.block,
typevars,
self.call_stack,
);
results.first()
}
assert!(index_u32 < len, "get_element invoked with an out of bounds index");
let get = Instruction::ArrayGet { array, index };
let results = self.dfg.insert_instruction_and_results(
get,
self.block,
typevars,
self.call_stack,
);
results.first()
};

// If it's out of bounds for the "then" slice, a value in the "else" *must* exist.
// We can use that value directly as accessing it is always checked against the actual
// slice length.
if index_u32 >= flattened_then_length {
let else_element = get_element(else_value_id, typevars, flattened_else_length);
if index_u32 >= flat_then_length {
let else_element = get_element(else_value_id, typevars, flat_else_length);
merged.push_back(else_element);
continue;
}

// Same for if it's out of bounds for the "else" slice.
if index_u32 >= flattened_else_length {
let then_element =
get_element(then_value_id, typevars.clone(), flattened_then_length);
if index_u32 >= flat_else_length {
let then_element = get_element(then_value_id, typevars, flat_then_length);
merged.push_back(then_element);
continue;
}

let then_element =
get_element(then_value_id, typevars.clone(), flattened_then_length);
let else_element = get_element(else_value_id, typevars, flattened_else_length);
let then_element = get_element(then_value_id, typevars.clone(), flat_then_length);
let else_element = get_element(else_value_id, typevars, flat_else_length);

merged.push_back(self.merge_values(
then_condition,
Expand Down
19 changes: 12 additions & 7 deletions compiler/noirc_evaluator/src/ssa/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,25 @@ impl Type {
/// The size of a type is defined as representing how many Fields are needed
/// to represent the type. This is 1 for every primitive type, and is the number of fields
/// for any flattened tuple type.
///
/// Panics if `self` is not a [`Type::Array`] or [`Type::Slice`].
pub(crate) fn element_size(&self) -> usize {
match self {
Type::Array(elements, _) | Type::Slice(elements) => elements.len(),
other => panic!("element_size: Expected array or slice, found {other}"),
}
}

/// Return the types of items in this array/slice.
///
/// Panics if `self` is not a [`Type::Array`] or [`Type::Slice`].
pub(crate) fn element_types(self) -> Arc<Vec<Type>> {
match self {
Type::Array(element_types, _) | Type::Slice(element_types) => element_types,
other => panic!("element_types: Expected array or slice, found {other}"),
}
}

pub(crate) fn contains_slice_element(&self) -> bool {
match self {
Type::Array(elements, _) => {
Expand Down Expand Up @@ -269,13 +281,6 @@ impl Type {
}
}

pub(crate) fn element_types(self) -> Arc<Vec<Type>> {
match self {
Type::Array(element_types, _) | Type::Slice(element_types) => element_types,
other => panic!("element_types: Expected array or slice, found {other}"),
}
}

pub(crate) fn first(&self) -> Type {
match self {
Type::Numeric(_) | Type::Function => self.clone(),
Expand Down
62 changes: 42 additions & 20 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//! ```
//!
//! If the shift amount is not a constant, 2^N is computed via square&multiply,
//! using the bits decomposition of exponent.
//! using the bits decomposition of the exponent.
//!
//! Pseudo-code of the computation:
//!
Expand All @@ -33,7 +33,7 @@
//!
//! ## Unsigned shift-left
//!
//! Shifting an unsigned integer to the right by N is the same as multiplying by 2^N.
//! Shifting an unsigned integer to the left by N is the same as multiplying by 2^N.
//! However, since that can overflow the target bit size, the operation is done using
//! Field, then truncated to the target bit size.
//!
Expand Down Expand Up @@ -137,8 +137,8 @@ struct Context<'m, 'dfg, 'mapping> {
}

impl Context<'_, '_, '_> {
/// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs
/// and truncate the result to bit_size
/// Insert SSA instructions which computes lhs << rhs by doing lhs*2^rhs
/// and truncate the result to `bit_size`.
fn insert_wrapping_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let typ = self.context.dfg.type_of_value(lhs).unwrap_numeric();
let max_lhs_bits = self.context.dfg.get_value_max_num_bits(lhs);
Expand Down Expand Up @@ -186,7 +186,7 @@ impl Context<'_, '_, '_> {
let result = self.insert_truncate(result, typ.bit_size(), max_bit);
self.insert_cast(result, typ)
} else {
// Otherwise, the result might not bit in a FieldElement.
// Otherwise, the result might not fit in a FieldElement.
// For this, if we have to do `lhs << rhs` we can first shift by half of `rhs`, truncate,
// then shift by `rhs - half_of_rhs` and truncate again.
assert!(typ.bit_size() <= 128);
Expand All @@ -196,12 +196,12 @@ impl Context<'_, '_, '_> {
// rhs_divided_by_two = rhs / 2
let rhs_divided_by_two = self.insert_binary(rhs, BinaryOp::Div, two);

// rhs_remainder = rhs - rhs_remainder
// rhs_remainder = rhs - rhs_divided_by_two
let rhs_remainder =
self.insert_binary(rhs, BinaryOp::Sub { unchecked: true }, rhs_divided_by_two);

// pow1 = 2^rhs_divided_by_two
// pow2 = r^rhs_remainder
// pow2 = 2^rhs_remainder
let pow1 = self.two_pow(rhs_divided_by_two);
let pow2 = self.two_pow(rhs_remainder);

Expand All @@ -216,9 +216,29 @@ impl Context<'_, '_, '_> {
}
}

/// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs
/// For negative signed integers, we do the division on the 1-complement representation of lhs,
/// before converting back the result to the 2-complement representation.
/// Insert SSA instructions which computes lhs >> rhs by doing lhs/2^rhs
///
/// For negative signed integers, we do the shifting using a technique based on how dividing a
/// 2-complement value can be done by converting to the 1-complement representation of lhs,
/// shifting, then converting back the result to the 2-complement representation.
///
/// To understand the algorithm, take a look at how division works on pages 7-8 of
/// <https://dspace.mit.edu/bitstream/handle/1721.1/6090/AIM-378.pdf>
///
/// Division for a negative number represented as a 2-complement is implemented by the following steps:
/// 1. Convert to 1-complement by subtracting 1 from the value
/// 2. Shift right by the number of bits corresponding to the divisor
/// 3. Convert back to 2-complement by adding 1 to the result
///
/// That's division in terms of shifting; we need shifting in terms of division. The following steps show how:
/// * `DIV(a) = SHR(a-1)+1`
/// * `SHR(a-1) = DIV(a)-1`
/// * `SHR(a) = DIV(a+1)-1`
///
/// Hence we handle negative values in shifting by:
/// 1. Adding 1 to the value
/// 2. Dividing by 2^rhs
/// 3. Subtracting 1 from the result
fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId {
let lhs_typ = self.context.dfg.type_of_value(lhs).unwrap_numeric();

Expand All @@ -234,24 +254,24 @@ impl Context<'_, '_, '_> {
// Get the sign of the operand; positive signed operand will just do a division as well
let zero =
self.numeric_constant(FieldElement::zero(), NumericType::signed(bit_size));
// The sign will be 0 for positive numbers and 1 for negatives, so it covers both cases.
let lhs_sign = self.insert_binary(lhs, BinaryOp::Lt, zero);
let lhs_sign_as_field = self.insert_cast(lhs_sign, NumericType::NativeField);
let lhs_as_field = self.insert_cast(lhs, NumericType::NativeField);
// For negative numbers, convert to 1-complement using wrapping addition of a + 1
// Unchecked add as these are fields
// For negative numbers, we prepare for the division using a wrapping addition of a + 1. Unchecked add as these are fields.
let add = BinaryOp::Add { unchecked: true };
let one_complement = self.insert_binary(lhs_sign_as_field, add, lhs_as_field);
let one_complement = self.insert_truncate(one_complement, bit_size, bit_size + 1);
let one_complement =
self.insert_cast(one_complement, NumericType::signed(bit_size));
// Performs the division on the 1-complement (or the operand if positive)
let shifted_complement = self.insert_binary(one_complement, BinaryOp::Div, pow);
// Convert back to 2-complement representation if operand is negative
let div_complement = self.insert_binary(lhs_sign_as_field, add, lhs_as_field);
let div_complement = self.insert_truncate(div_complement, bit_size, bit_size + 1);
let div_complement =
self.insert_cast(div_complement, NumericType::signed(bit_size));
// Performs the division on the adjusted complement (or the operand if positive)
let shifted_complement = self.insert_binary(div_complement, BinaryOp::Div, pow);
// For negative numbers, convert back to 2-complement by subtracting 1.
let lhs_sign_as_int = self.insert_cast(lhs_sign, lhs_typ);

// The requirements for this to underflow are all of these:
// - lhs < 0
// - ones_complement(lhs) / (2^rhs) == 0
// - div_complement(lhs) / (2^rhs) == 0
// As the upper bit is set for the ones complement of negative numbers we'd need 2^rhs
// to be larger than the lhs bitsize for this to overflow.
let sub = BinaryOp::Sub { unchecked: true };
Expand All @@ -265,13 +285,15 @@ impl Context<'_, '_, '_> {

/// Computes 2^exponent via square&multiply, using the bits decomposition of exponent
/// Pseudo-code of the computation:
/// ```text
/// let mut r = 1;
/// let exponent_bits = to_bits(exponent);
/// for i in 1 .. bit_size + 1 {
/// let r_squared = r * r;
/// let b = exponent_bits[bit_size - i];
/// r = if b { 2 * r_squared } else { r_squared };
/// }
/// ```
fn two_pow(&mut self, exponent: ValueId) -> ValueId {
// Require that exponent < bit_size, ensuring that `pow` returns a value consistent with `lhs`'s type.
let max_bit_size = self.context.dfg.type_of_value(exponent).bit_size();
Expand Down
Loading
Loading