Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 82 additions & 26 deletions compiler/noirc_frontend/src/hir_def/types/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,27 @@ impl Type {
op: BinaryTypeOperator,
rhs: &Type,
) -> Option<Type> {
let Type::InfixExpr(l_lhs, l_op, l_rhs) = lhs.follow_bindings() else {
return None;
};
match lhs.follow_bindings() {
Type::CheckedCast { from, to } => {
// Apply operation directly to `from` while attempting simplification to `to`.
let from = Type::InfixExpr(from, op, Box::new(rhs.clone()));
let to = Self::try_simplify_non_constants_in_lhs(&to, op, rhs)?;
Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) })
}
Type::InfixExpr(l_lhs, l_op, l_rhs) => {
// Note that this is exact, syntactic equality, not unification.
// `rhs` is expected to already be in canonical form.
if l_op.approx_inverse() != Some(op)
|| l_op == BinaryTypeOperator::Division
|| l_rhs.canonicalize_unchecked() != *rhs
{
return None;
}

// Note that this is exact, syntactic equality, not unification.
// `rhs` is expected to already be in canonical form.
if l_op.approx_inverse() != Some(op)
|| l_op == BinaryTypeOperator::Division
|| l_rhs.canonicalize_unchecked() != *rhs
{
return None;
Some(*l_lhs)
}
_ => None,
}

Some(*l_lhs)
}

/// Try to simplify non-constant expressions in the form `N op1 (M op1 N)`
Expand All @@ -219,23 +226,31 @@ impl Type {
op: BinaryTypeOperator,
rhs: &Type,
) -> Option<Type> {
let Type::InfixExpr(r_lhs, r_op, r_rhs) = rhs.follow_bindings() else {
return None;
};
match rhs.follow_bindings() {
Type::CheckedCast { from, to } => {
// Apply operation directly to `from` while attempting simplification to `to`.
let from = Type::InfixExpr(Box::new(lhs.clone()), op, from);
let to = Self::try_simplify_non_constants_in_rhs(lhs, op, &to)?;
Some(Type::CheckedCast { from: Box::new(from), to: Box::new(to) })
}
Type::InfixExpr(r_lhs, r_op, r_rhs) => {
// `N / (M * N)` should be simplified to `1 / M`, but we only handle
// simplifying to `M` in this function.
if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication
{
return None;
}

// `N / (M * N)` should be simplified to `1 / M`, but we only handle
// simplifying to `M` in this function.
if op == BinaryTypeOperator::Division && r_op == BinaryTypeOperator::Multiplication {
return None;
}
// Note that this is exact, syntactic equality, not unification.
// `lhs` is expected to already be in canonical form.
if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() {
return None;
}

// Note that this is exact, syntactic equality, not unification.
// `lhs` is expected to already be in canonical form.
if r_op.inverse() != Some(op) || *lhs != r_rhs.canonicalize_unchecked() {
return None;
Some(*r_lhs)
}
_ => None,
}

Some(*r_lhs)
}

/// Given:
Expand Down Expand Up @@ -360,6 +375,47 @@ mod tests {

use crate::hir_def::types::{BinaryTypeOperator, Kind, Type, TypeVariable, TypeVariableId};

#[test]
fn solves_n_minus_one_plus_one_through_checked_casts() {
// We want to test that the inclusion of a `CheckedCast` won't prevent us from canonicalizing
// the expression `(N - 1) + 1` to `N` if there exists a `CheckedCast` on the `N - 1` term.

let n = Type::NamedGeneric(
TypeVariable::unbound(TypeVariableId(0), Kind::u32()),
std::rc::Rc::new("N".to_owned()),
);
let n_minus_one = Type::InfixExpr(
Box::new(n.clone()),
BinaryTypeOperator::Subtraction,
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
);
let checked_cast_n_minus_one =
Type::CheckedCast { from: Box::new(n_minus_one.clone()), to: Box::new(n_minus_one) };

let n_minus_one_plus_one = Type::InfixExpr(
Box::new(checked_cast_n_minus_one.clone()),
BinaryTypeOperator::Addition,
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
);

let canonicalized_typ = n_minus_one_plus_one.canonicalize();

assert_eq!(n, canonicalized_typ);

// We also want to check that if the `CheckedCast` is on the RHS then we'll still be able to canonicalize
// the expression `1 + (N - 1)` to `N`.

let one_plus_n_minus_one = Type::InfixExpr(
Box::new(Type::Constant(FieldElement::one(), Kind::u32())),
BinaryTypeOperator::Addition,
Box::new(checked_cast_n_minus_one),
);

let canonicalized_typ = one_plus_n_minus_one.canonicalize();

assert_eq!(n, canonicalized_typ);
}

#[test]
fn instantiate_after_canonicalize_smoke_test() {
let field_element_kind = Kind::numeric(Type::FieldElement);
Expand Down
30 changes: 30 additions & 0 deletions compiler/noirc_frontend/src/tests/arithmetic_generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,36 @@ fn arithmetic_generics_canonicalization_deduplication_regression() {
assert_eq!(errors.len(), 0);
}

#[test]
fn checked_casts_do_not_prevent_canonicalization() {
// Regression test for https://github.com/noir-lang/noir/issues/6495
let source = r#"
pub trait Serialize<let N: u32> {
fn serialize(self) -> [Field; N];
}

pub struct Counted<T> {
pub inner: T,
}

pub fn append<T, let N: u32>(array1: [T; N]) -> [T; N + 1] {
[array1[0]; N + 1]
}

impl<T, let N: u32> Serialize<N> for Counted<T>
where
T: Serialize<N - 1>,
{
fn serialize(self) -> [Field; N] {
append(self.inner.serialize())
}
}
"#;
let errors = get_program_errors(source);
println!("{:?}", errors);
assert_eq!(errors.len(), 0);
}

#[test]
fn arithmetic_generics_checked_cast_zeros() {
let source = r#"
Expand Down
2 changes: 2 additions & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
"callsites",
"callstack",
"callstacks",
"canonicalization",
"canonicalize",
"canonicalized",
"canonicalizing",
"castable",
"catmcgee",
"Celo",
Expand Down