Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion acvm-repo/acir_field/src/field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl<F: PrimeField> FieldElement<F> {
self.0
}

fn fits_in_u128(&self) -> bool {
pub fn fits_in_u128(&self) -> bool {
self.num_bits() <= 128
}

Expand Down
159 changes: 98 additions & 61 deletions compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! In Brillig an overflow check is automatically performed on unsigned binary operations
//! so this SSA pass has no effect for Brillig functions.
use acvm::{AcirField, FieldElement};
use num_bigint::BigUint;

use crate::ssa::{
ir::{
Expand Down Expand Up @@ -60,6 +61,21 @@ impl Function {
}
}

/// MAX_NON_OVERFLOWING_CONST_ARG is expected to be [p/U],
/// where U=U128::max() and p is the field modulus.
///
/// Then x<=[p/U]<p/U, so x*U<p
static MAX_NON_OVERFLOWING_CONST_ARG: std::sync::LazyLock<u128> = std::sync::LazyLock::new(|| {
let max_non_overflowing_const_arg = u128::try_from(FieldElement::modulus() / u128::MAX)
.expect("expected max_const_value_that_does_not_overflow to fit into a u128");
assert!(BigUint::from(u128::MAX) * max_non_overflowing_const_arg < FieldElement::modulus());
max_non_overflowing_const_arg
});

fn max_non_overflowing_const_arg() -> u128 {
*MAX_NON_OVERFLOWING_CONST_ARG
}

fn check_u128_mul_overflow(
lhs: ValueId,
rhs: ValueId,
Expand All @@ -69,11 +85,17 @@ fn check_u128_mul_overflow(
let lhs_value = dfg.get_numeric_constant(lhs);
let rhs_value = dfg.get_numeric_constant(rhs);

// If we multiply a constant value 2^n by an unknown u128 value we get at most `2^(n+128) - 2`.
// If `n+128` does not overflow the maximum Field element value, there's no need to check for overflow.
let max_const_value_that_does_not_overflow = 1_u128 << (FieldElement::max_num_bits() - 128);
if lhs_value.is_some_and(|value| value.to_u128() < max_const_value_that_does_not_overflow)
|| rhs_value.is_some_and(|value| value.to_u128() < max_const_value_that_does_not_overflow)
assert!(
lhs_value.map(|value| value.fits_in_u128()).unwrap_or(true),
"expected lhs_value to fit in a u128, but found {lhs_value:?}"
);
assert!(
rhs_value.map(|value| value.fits_in_u128()).unwrap_or(true),
"expected rhs_value to fit in a u128, but found {rhs_value:?}"
);

if lhs_value.is_some_and(|value| value.to_u128() <= max_non_overflowing_const_arg())
|| rhs_value.is_some_and(|value| value.to_u128() <= max_non_overflowing_const_arg())
{
return;
}
Expand Down Expand Up @@ -143,55 +165,62 @@ fn check_u128_mul_overflow(
mod tests {
use crate::{
assert_ssa_snapshot,
ssa::{opt::assert_ssa_does_not_change, ssa_gen::Ssa},
ssa::{
opt::{
assert_ssa_does_not_change, check_u128_mul_overflow::max_non_overflowing_const_arg,
},
ssa_gen::Ssa,
},
};

#[test]
fn does_not_insert_check_if_multiplying_lhs_will_not_overflow_field_element() {
// The big value here is 2^254 - 2^128 - 1, which, when multiplied by any u128
// won't overflow a Field element max value.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u128):
v2 = mul u128 85070591730234615865843651857942052863, v0
v2 = mul u128 {}, v0
return
}
";
assert_ssa_does_not_change(src, Ssa::check_u128_mul_overflow);
}}
",
max_non_overflowing_const_arg()
);
assert_ssa_does_not_change(&src, Ssa::check_u128_mul_overflow);
}

#[test]
fn does_not_insert_check_if_multiplying_rhs_will_not_overflow_field_element() {
// The big value here is 2^254 - 2^128 - 1, which, when multiplied by any u128
// won't overflow a Field element max value.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u128):
v2 = mul v0, u128 85070591730234615865843651857942052863
v2 = mul v0, u128 {}
return
}
";
assert_ssa_does_not_change(src, Ssa::check_u128_mul_overflow);
}}
",
max_non_overflowing_const_arg()
);
assert_ssa_does_not_change(&src, Ssa::check_u128_mul_overflow);
}

#[test]
fn inserts_check_for_lhs() {
// The big value here is 2^254 - 2^128, which, when multiplied by any u128
// might overflow a Field element max value.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u128):
v2 = mul v0, u128 85070591730234615865843651857942052864
v2 = mul v0, u128 {}
return
}
";
let ssa = Ssa::from_str(src).unwrap();
}}",
max_non_overflowing_const_arg() + 1
);
let ssa = Ssa::from_str(&src).unwrap();

let ssa = ssa.check_u128_mul_overflow();
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u128):
v2 = mul v0, u128 85070591730234615865843651857942052864
v2 = mul v0, u128 64323764613183177041862057485226039390
v4 = div v0, u128 18446744073709551616
constrain v4 == u128 0, "attempt to multiply with overflow"
return
Expand All @@ -201,22 +230,22 @@ mod tests {

#[test]
fn inserts_check_for_rhs() {
// The big value here is 2^254 - 2^128, which, when multiplied by any u128
// might overflow a Field element max value.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u128):
v2 = mul u128 85070591730234615865843651857942052864, v0
v2 = mul u128 {}, v0
return
}
";
let ssa = Ssa::from_str(src).unwrap();
}}",
max_non_overflowing_const_arg() + 1
);
let ssa = Ssa::from_str(&src).unwrap();

let ssa = ssa.check_u128_mul_overflow();
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u128):
v2 = mul u128 85070591730234615865843651857942052864, v0
v2 = mul u128 64323764613183177041862057485226039390, v0
v4 = div v0, u128 18446744073709551616
constrain v4 == u128 0, "attempt to multiply with overflow"
return
Expand Down Expand Up @@ -251,21 +280,24 @@ mod tests {

#[test]
fn inserts_assertion_failure_if_overflow_is_guaranteed() {
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0():
v2 = mul u128 85070591730234615865843651857942052864, u128 85070591730234615865843651857942052865
v2 = mul u128 {}, u128 {}
return
}
";
let ssa = Ssa::from_str(src).unwrap();
}}",
max_non_overflowing_const_arg() + 1,
max_non_overflowing_const_arg() + 1
);
let ssa = Ssa::from_str(&src).unwrap();

let ssa = ssa.check_u128_mul_overflow();
// The multiplication remains, but it will be later removed by DIE
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0():
v2 = mul u128 85070591730234615865843651857942052864, u128 85070591730234615865843651857942052865
v1 = mul u128 64323764613183177041862057485226039390, u128 64323764613183177041862057485226039390
constrain u128 1 == u128 0, "attempt to multiply with overflow"
return
}
Expand All @@ -287,22 +319,24 @@ mod tests {
#[test]
fn predicate_overflow_on_lhs_potentially_overflowing() {
// This code performs a u128 multiplication that overflows, under a condition.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u128, v1: u1):
enable_side_effects v1
v2 = mul v0, u128 85070591730234615865843651857942052864
v2 = mul v0, u128 {}
return v2
}
";
let ssa = Ssa::from_str(src).unwrap();
}}",
max_non_overflowing_const_arg() + 1
);
let ssa = Ssa::from_str(&src).unwrap();
let ssa = ssa.flatten_cfg().check_u128_mul_overflow();
// Below, the overflow check takes the 'enable_side_effects' value into account
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u128, v1: u1):
enable_side_effects v1
v3 = mul v0, u128 85070591730234615865843651857942052864
v3 = mul v0, u128 64323764613183177041862057485226039390
v5 = div v0, u128 18446744073709551616
v6 = cast v1 as u128
v7 = unchecked_mul v5, v6
Expand All @@ -315,25 +349,28 @@ mod tests {
#[test]
fn predicate_overflow_on_guaranteed_overflow() {
// This code performs a u128 multiplication that overflows, under a condition.
let src = "
acir(inline) fn main f0 {
let src = format!(
"
acir(inline) fn main f0 {{
b0(v0: u1):
jmpif v0 then: b1, else: b2
b1():
v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455
v2 = mul u128 {}, u128 {}
jmp b2()
b2():
return v0
}
";
let ssa = Ssa::from_str(src).unwrap();
}}",
max_non_overflowing_const_arg() + 1,
max_non_overflowing_const_arg() + 1
);
let ssa = Ssa::from_str(&src).unwrap();
let ssa = ssa.flatten_cfg().check_u128_mul_overflow();
// Below, the overflow check takes the 'enable_side_effects' value into account
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u1):
enable_side_effects v0
v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455
v2 = mul u128 64323764613183177041862057485226039390, u128 64323764613183177041862057485226039390
v3 = cast v0 as u128
constrain v0 == u1 0, "attempt to multiply with overflow"
v5 = not v0
Expand Down
Loading