Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 93 additions & 13 deletions compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use acvm::AcirField as _;
use fxhash::FxHashMap as HashMap;

use crate::ssa::{
ir::{
dfg::DataFlowGraph,
function::Function,
instruction::{Binary, BinaryOp, Instruction},
types::NumericType,
value::{Value, ValueId},
},
ssa_gen::Ssa,
};
Expand All @@ -21,6 +26,8 @@ impl Ssa {

impl Function {
fn checked_to_unchecked(&mut self) {
let mut value_max_num_bits = HashMap::<ValueId, u32>::default();

self.simple_reachable_blocks_optimization(|context| {
let instruction = context.instruction();
let Instruction::Binary(binary) = instruction else {
Expand All @@ -39,31 +46,33 @@ impl Function {
match binary.operator {
BinaryOp::Add { unchecked: false } => {
let bit_size = dfg.type_of_value(lhs).bit_size();
let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits);
let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits);

if dfg.get_value_max_num_bits(lhs) < bit_size
&& dfg.get_value_max_num_bits(rhs) < bit_size
{
if max_lhs_bits < bit_size && max_rhs_bits < bit_size {
// `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
let operator = BinaryOp::Add { unchecked: true };
let binary = Binary { operator, ..*binary };
context.replace_current_instruction_with(Instruction::Binary(binary));
}
}
BinaryOp::Sub { unchecked: false } => {
if dfg.is_constant(lhs)
&& dfg.get_value_max_num_bits(lhs) > dfg.get_value_max_num_bits(rhs)
{
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
let operator = BinaryOp::Sub { unchecked: true };
let binary = Binary { operator, ..*binary };
context.replace_current_instruction_with(Instruction::Binary(binary));
if dfg.is_constant(lhs) {
let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits);
let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits);
if max_lhs_bits > max_rhs_bits {
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
let operator = BinaryOp::Sub { unchecked: true };
let binary = Binary { operator, ..*binary };
context.replace_current_instruction_with(Instruction::Binary(binary));
}
}
}
BinaryOp::Mul { unchecked: false } => {
let bit_size = dfg.type_of_value(lhs).bit_size();
let max_lhs_bits = dfg.get_value_max_num_bits(lhs);
let max_rhs_bits = dfg.get_value_max_num_bits(rhs);
let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits);
let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits);

if bit_size == 1
|| max_lhs_bits + max_rhs_bits <= bit_size
Expand All @@ -83,6 +92,51 @@ impl Function {
}
}

/// The logic here is almost the same as [`DataFlowGraph::get_value_max_num_bits`] except that
/// - it takes into account that the bitsize of multiplying two bools is 1
/// - it recurses by memoizing the results in `value_max_num_bits`
fn get_max_num_bits(
dfg: &DataFlowGraph,
value: ValueId,
value_max_num_bits: &mut HashMap<ValueId, u32>,
) -> u32 {
if let Some(bits) = value_max_num_bits.get(&value) {
return *bits;
}

let value_bit_size = dfg.type_of_value(value).bit_size();

let bits = match dfg[value] {
Value::Instruction { instruction, .. } => {
match dfg[instruction] {
Instruction::Cast(original_value, _) => {
let original_bit_size =
get_max_num_bits(dfg, original_value, value_max_num_bits);
// We might have cast e.g. `u1` to `u8` to be able to do arithmetic,
// in which case we want to recover the original smaller bit size;
// OTOH if we cast down, then we don't need the higher original size.
value_bit_size.min(original_bit_size)
}
Instruction::Binary(Binary { lhs, operator: BinaryOp::Mul { .. }, rhs })
if get_max_num_bits(dfg, lhs, value_max_num_bits) == 1
&& get_max_num_bits(dfg, rhs, value_max_num_bits) == 1 =>
{
// When multiplying two values, if their bitsize is 1 then the result's bitsize will be 1 too
1
}
_ => value_bit_size,
}
}
Value::NumericConstant { constant, .. } => constant.num_bits(),
_ => value_bit_size,
};

assert!(bits <= value_bit_size);
value_max_num_bits.insert(value, bits);

bits
}

#[cfg(test)]
mod tests {
use crate::{
Expand Down Expand Up @@ -223,4 +277,30 @@ mod tests {
let ssa = ssa.checked_to_unchecked();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn checked_to_unchecked_when_multiplying_two_upcasted_bools_to_u32_then_multiplying_again() {
let src = "
acir(inline) fn main f0 {
b0(v0: u1, v1: u1, v2: u32):
v3 = cast v0 as u32
v4 = cast v1 as u32
v5 = mul v3, v4
v6 = mul v2, v5
return v6
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.checked_to_unchecked();
assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0(v0: u1, v1: u1, v2: u32):
v3 = cast v0 as u32
v4 = cast v1 as u32
v5 = unchecked_mul v3, v4
v6 = unchecked_mul v2, v5
return v6
}
");
}
}
Loading