diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 4e1dc9b15a0..96b907073a1 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -49,6 +49,7 @@ pub mod ir; pub(crate) mod opt; pub mod parser; pub mod ssa_gen; +pub(crate) mod validation; #[derive(Debug, Clone)] pub enum SsaLogging { diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 04225faeb48..cc1f620aa15 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -29,6 +29,7 @@ use super::{ }, opt::pure::FunctionPurities, ssa_gen::Ssa, + validation::validate_function, }; /// The per-function context for each ssa function being generated. @@ -548,7 +549,7 @@ impl FunctionBuilder { fn validate_ssa(functions: &[Function]) { for function in functions { - function.assert_valid(); + validate_function(function); } } } diff --git a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs index f5f19ef5f6d..f5cda1eb7e5 100644 --- a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs +++ b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs @@ -214,7 +214,10 @@ fn mul_signed() { acir(inline) fn main f0 { b0(): v0 = mul i64 2, i64 100 - return v0 + v1 = cast v0 as u128 + v2 = truncate v1 to 64 bits, max_bit_size: 128 + v3 = cast v2 as i64 + return v3 } ", ); @@ -237,11 +240,17 @@ fn mul_overflow_unsigned() { #[test] fn mul_overflow_signed() { + // We return v0 as we simply want the output from the mul operation in this test. + // However, the valid SSA signed overflow patterns requires that the appropriate + // casts and truncates follow a signed mul. let value = expect_value( " acir(inline) fn main f0 { b0(): v0 = mul i8 127, i8 2 + v1 = cast v0 as u16 + v2 = truncate v1 to 8 bits, max_bit_size: 16 + v3 = cast v2 as i8 return v0 } ", diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index a147da86b90..a0af3979271 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -9,8 +9,7 @@ use super::{ basic_block::{BasicBlock, BasicBlockId}, function::{FunctionId, RuntimeType}, instruction::{ - Binary, BinaryOp, Instruction, InstructionId, InstructionResultType, Intrinsic, - TerminatorInstruction, + Instruction, InstructionId, InstructionResultType, Intrinsic, TerminatorInstruction, }, integer::IntegerConstant, map::DenseMap, @@ -224,62 +223,11 @@ impl DataFlowGraph { instruction_data: Instruction, ctrl_typevars: Option>, ) -> InstructionId { - self.validate_instruction(&instruction_data); - let id = self.instructions.insert(instruction_data); self.make_instruction_results(id, ctrl_typevars); id } - fn validate_instruction(&self, instruction: &Instruction) { - match instruction { - Instruction::Binary(Binary { lhs, rhs, operator }) => { - let lhs_type = self.type_of_value(*lhs); - let rhs_type = self.type_of_value(*rhs); - match operator { - BinaryOp::Lt => { - if lhs_type != rhs_type { - panic!( - "Left-hand side and right-hand side of `lt` must have the same type" - ); - } - - if matches!(lhs_type, Type::Numeric(NumericType::NativeField)) { - panic!("Cannot use `lt` with field elements"); - } - } - BinaryOp::Shl => { - if !matches!(rhs_type, Type::Numeric(NumericType::Unsigned { bit_size: 8 })) - { - panic!("Right-hand side of `shl` must be u8"); - } - } - BinaryOp::Shr => { - if !matches!(rhs_type, Type::Numeric(NumericType::Unsigned { bit_size: 8 })) - { - panic!("Right-hand side of `shr` must be u8"); - } - } - _ => { - if lhs_type != rhs_type { - panic!( - "Left-hand side and right-hand side of `{}` must have the same type", - operator - ); - } - } - } - } - Instruction::ArrayGet { index, .. } | Instruction::ArraySet { index, .. } => { - let index_type = self.type_of_value(*index); - if !matches!(index_type, Type::Numeric(NumericType::Unsigned { bit_size: 32 })) { - panic!("ArrayGet/ArraySet index must be u32"); - } - } - _ => (), - } - } - /// Check if the function runtime would simply ignore this instruction. pub(crate) fn is_handled_by_runtime(&self, instruction: &Instruction) -> bool { match self.runtime() { @@ -953,10 +901,7 @@ impl std::ops::Index for InsertInstructionResult<'_> { #[cfg(test)] mod tests { use super::DataFlowGraph; - use crate::ssa::{ - ir::{instruction::Instruction, types::Type}, - ssa_gen::Ssa, - }; + use crate::ssa::ir::{instruction::Instruction, types::Type}; #[test] fn make_instruction() { @@ -967,58 +912,4 @@ mod tests { let results = dfg.instruction_results(ins_id); assert_eq!(results.len(), 1); } - - #[test] - #[should_panic(expected = "Cannot use `lt` with field elements")] - fn disallows_comparing_fields_with_lt() { - let src = " - acir(inline) impure fn main f0 { - b0(): - v2 = lt Field 1, Field 2 - return - } - "; - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic( - expected = "Left-hand side and right-hand side of `add` must have the same type" - )] - fn disallows_binary_add_with_different_types() { - let src = " - acir(inline) fn main f0 { - b0(): - v2 = add Field 1, i32 2 - return - } - "; - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic(expected = "Right-hand side of `shr` must be u8")] - fn disallows_shr_with_non_u8() { - let src = " - acir(inline) fn main f0 { - b0(): - v2 = shr u32 1, u16 1 - return - } - "; - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic(expected = "Right-hand side of `shl` must be u8")] - fn disallows_shl_with_non_u8() { - let src = " - acir(inline) fn main f0 { - b0(): - v2 = shl u32 1, u16 1 - return - } - "; - let _ = Ssa::from_str(src); - } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index dd85cfa936f..77bb1191a37 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use super::basic_block::BasicBlockId; use super::dfg::{DataFlowGraph, GlobalsGraph}; -use super::instruction::{BinaryOp, Instruction, TerminatorInstruction}; +use super::instruction::TerminatorInstruction; use super::map::Id; use super::types::{NumericType, Type}; use super::value::{Value, ValueId}; @@ -231,89 +231,6 @@ impl Function { } }) } - - /// Asserts that the [`Function`] is well formed. - /// - /// Panics on malformed functions. - pub(crate) fn assert_valid(&self) { - self.assert_single_return_block(); - self.validate_signed_arithmetic_invariants(); - } - - /// Checks that the function has only one return block. - fn assert_single_return_block(&self) { - let reachable_blocks = self.reachable_blocks(); - - // We assume that all functions have a single block which terminates with a `return` instruction. - let return_blocks: BTreeSet<_> = reachable_blocks - .iter() - .filter(|block| { - // All blocks must have a terminator instruction of some sort. - let terminator = self.dfg[**block].terminator().unwrap_or_else(|| { - panic!("Function {} has no terminator in block {block}", self.id()) - }); - matches!(terminator, TerminatorInstruction::Return { .. }) - }) - .collect(); - if return_blocks.len() > 1 { - panic!("Function {} has multiple return blocks {return_blocks:?}", self.id()) - } - } - - /// Validates that any checked signed add/sub is followed by the expected truncate. - fn validate_signed_arithmetic_invariants(&self) { - // State for tracking the last signed binary addition/subtraction - let mut signed_binary_op = None; - for block in self.reachable_blocks() { - for instruction in self.dfg[block].instructions() { - match &self.dfg[*instruction] { - Instruction::Binary(binary) => { - signed_binary_op = None; - - match binary.operator { - // We are only validating addition/subtraction - BinaryOp::Add { unchecked: false } - | BinaryOp::Sub { unchecked: false } => {} - // Otherwise, move onto the next instruction - _ => continue, - } - - // Assume rhs_type is the same as lhs_type - let lhs_type = self.dfg.type_of_value(binary.lhs); - if let Type::Numeric(NumericType::Signed { bit_size }) = lhs_type { - let results = self.dfg.instruction_results(*instruction); - signed_binary_op = Some((bit_size, results[0])); - } - } - Instruction::Truncate { value, bit_size, max_bit_size } => { - let Some((signed_op_bit_size, signed_op_res)) = signed_binary_op.take() - else { - continue; - }; - assert_eq!( - *bit_size, signed_op_bit_size, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); - assert_eq!( - *max_bit_size, - *bit_size + 1, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); - assert_eq!( - *value, signed_op_res, - "ICE: Correct truncate must follow the result of a checked signed add/sub" - ); - } - _ => { - signed_binary_op = None; - } - } - } - } - if signed_binary_op.is_some() { - panic!("ICE: Truncate must follow the result of a checked signed add/sub"); - } - } } impl Clone for Function { @@ -356,155 +273,3 @@ fn sign_smoke() { signature.params.push(Type::Numeric(NumericType::NativeField)); signature.returns.push(Type::Numeric(NumericType::Unsigned { bit_size: 32 })); } - -#[cfg(test)] -mod validation { - use crate::ssa::ssa_gen::Ssa; - - #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] - fn lone_signed_sub_acir() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - return v2 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] - fn lone_signed_sub_brillig() { - // This matches the test above we just want to make sure it holds in the Brillig runtime as well as ACIR - let src = r" - brillig(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - return v2 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] - fn lone_signed_add_acir() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = add v0, v1 - return v2 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic(expected = "ICE: Truncate must follow the result of a checked signed add/sub")] - fn lone_signed_add_brillig() { - let src = r" - brillig(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = add v0, v1 - return v2 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic( - expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" - )] - fn signed_sub_bad_truncate_bit_size() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - v3 = truncate v2 to 32 bits, max_bit_size: 33 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - #[should_panic( - expected = "ICE: Correct truncate must follow the result of a checked signed add/sub" - )] - fn signed_sub_bad_truncate_max_bit_size() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - v3 = truncate v2 to 16 bits, max_bit_size: 18 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - fn truncate_follows_signed_sub_acir() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - v3 = truncate v2 to 16 bits, max_bit_size: 17 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - fn truncate_follows_signed_sub_brillig() { - let src = r" - brillig(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = sub v0, v1 - v3 = truncate v2 to 16 bits, max_bit_size: 17 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - fn truncate_follows_signed_add_acir() { - let src = r" - acir(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = add v0, v1 - v3 = truncate v2 to 16 bits, max_bit_size: 17 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } - - #[test] - fn truncate_follows_signed_add_brillig() { - let src = r" - brillig(inline) pure fn main f0 { - b0(v0: i16, v1: i16): - v2 = add v0, v1 - v3 = truncate v2 to 16 bits, max_bit_size: 17 - return v3 - } - "; - - let _ = Ssa::from_str(src); - } -} diff --git a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs index 30bf745635e..2a23f92f2a1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs @@ -324,6 +324,8 @@ mod tests { b0(v0: u1, v1: i32): v2 = cast v0 as i32 v3 = mul v2, v1 + v4 = cast v3 as u64 + v6 = truncate v4 to 32 bits, max_bit_size: 64 return v2 } "; diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index b208bdb68fd..005fbc481cc 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -1027,7 +1027,7 @@ mod test { b2(): return b3(): - v6 = mul v0, v1 + v6 = unchecked_mul v0, v1 constrain v6 == i32 6 v8 = unchecked_add v2, i32 1 jmp b1(v8) @@ -1053,7 +1053,7 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: i32, v1: i32): - v3 = mul v0, v1 + v3 = unchecked_mul v0, v1 constrain v3 == i32 6 jmp b1(i32 0) b1(v2: i32): @@ -1093,7 +1093,7 @@ mod test { v9 = unchecked_add v2, i32 1 jmp b1(v9) b6(): - v10 = mul v0, v1 + v10 = unchecked_mul v0, v1 constrain v10 == i32 6 v12 = unchecked_add v3, i32 1 jmp b4(v12) @@ -1110,7 +1110,7 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: i32, v1: i32): - v4 = mul v0, v1 + v4 = unchecked_mul v0, v1 constrain v4 == i32 6 jmp b1(i32 0) b1(v2: i32): @@ -1159,8 +1159,8 @@ mod test { b2(): return b3(): - v6 = mul v0, v1 - v7 = mul v6, v0 + v6 = unchecked_mul v0, v1 + v7 = unchecked_mul v6, v0 v8 = eq v7, i32 12 constrain v7 == i32 12 v9 = unchecked_add v2, i32 1 @@ -1177,8 +1177,8 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: i32, v1: i32): - v3 = mul v0, v1 - v4 = mul v3, v0 + v3 = unchecked_mul v0, v1 + v4 = unchecked_mul v3, v0 v6 = eq v4, i32 12 constrain v4 == i32 12 jmp b1(i32 0) @@ -1446,7 +1446,7 @@ mod test { b2(): return b3(): - v6 = mul v0, v1 + v6 = unchecked_mul v0, v1 constrain v6 == u32 6 v8 = add v2, u32 1 jmp b1(v8) @@ -1459,7 +1459,7 @@ mod test { let expected = " brillig(inline) fn main f0 { b0(v0: u32, v1: u32): - v3 = mul v0, v1 + v3 = unchecked_mul v0, v1 constrain v3 == u32 6 jmp b1(u32 0) b1(v2: u32): diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 857d596e044..d0e9a6eb56f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -1490,9 +1490,9 @@ mod tests { v11 = array_get v0, index v10 -> u64 v12 = add v11, u64 1 v13 = array_set v9, index v10, value v12 - v15 = add v1, {idx_type} 1 + v15 = unchecked_add v1, {idx_type} 1 store v13 at v4 - v16 = add v1, {idx_type} 1 // duplicate + v16 = unchecked_add v1, {idx_type} 1 // duplicate jmp b1(v16) b2(): v8 = load v4 -> [u64; 6] diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index 5b4eac11658..e8a78c63ad0 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -557,10 +557,6 @@ impl Translator { // before each print. ssa.normalize_ids(); - for function in ssa.functions.values() { - function.assert_valid(); - } - ssa } diff --git a/compiler/noirc_evaluator/src/ssa/validation/mod.rs b/compiler/noirc_evaluator/src/ssa/validation/mod.rs new file mode 100644 index 00000000000..301c79aa5c1 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/validation/mod.rs @@ -0,0 +1,678 @@ +//! Validator that checks whether a function is well formed. +//! +//! It validates: +//! +//! SSA form +//! +//! - That the function contains exactly one return block. +//! - That every checked signed addition or subtraction instruction is +//! followed by a corresponding truncate instruction with the expected bit sizes. +//! +//! Type checking +//! - Check that the input values of certain instructions matches that instruction's constraint +//! At the moment, only [Instruction::Binary], [Instruction::ArrayGet], and [Instruction::ArraySet] +//! are type checked. +use fxhash::FxHashSet as HashSet; + +use crate::ssa::ir::instruction::TerminatorInstruction; + +use super::ir::{ + function::Function, + instruction::{Binary, BinaryOp, Instruction, InstructionId, Intrinsic}, + types::{NumericType, Type}, + value::{Value, ValueId}, +}; + +/// Aside the function being validated, the validator maintains internal state +/// during instruction visitation to track patterns that span multiple instructions. +struct Validator<'f> { + function: &'f Function, + // State for truncate-after-signed-sub validation + // Stores: Option<(bit_size, result)> + signed_binary_op: Option, +} + +#[derive(Debug)] +enum PendingSignedOverflowOp { + AddOrSub { bit_size: u32, result: ValueId }, + Mul { bit_size: u32, mul_result: ValueId, cast_result: Option }, +} + +impl<'f> Validator<'f> { + fn new(function: &'f Function) -> Self { + Self { function, signed_binary_op: None } + } + + /// Validates that any checked signed add/sub/mul are followed by the appropriate instructions. + /// Signed overflow is many instructions but we validate up to the initial truncate. + /// + /// Expects the following SSA form for signed checked operations: + /// Add/Sub -> Truncate + /// Mul -> Cast -> Truncate + fn validate_signed_op_overflow_pattern(&mut self, instruction: InstructionId) { + let dfg = &self.function.dfg; + match &dfg[instruction] { + Instruction::Binary(binary) => { + // Only reset if we are starting a new tracked op. + // We do not reset on unrelated ops. If we already an op pending, we have an ill formed signed op. + if self.signed_binary_op.is_some() { + panic!("Signed binary operation does not follow overflow pattern"); + } + + // Assumes rhs_type is the same as lhs_type + let lhs_type = dfg.type_of_value(binary.lhs); + let Type::Numeric(NumericType::Signed { bit_size }) = lhs_type else { + return; + }; + + let result = dfg.instruction_results(instruction)[0]; + match binary.operator { + BinaryOp::Mul { unchecked: false } => { + self.signed_binary_op = Some(PendingSignedOverflowOp::Mul { + bit_size, + mul_result: result, + cast_result: None, + }); + } + BinaryOp::Add { unchecked: false } | BinaryOp::Sub { unchecked: false } => { + self.signed_binary_op = + Some(PendingSignedOverflowOp::AddOrSub { bit_size, result }); + } + _ => {} + } + } + Instruction::Truncate { value, bit_size, max_bit_size } => { + // Only a truncate can reset the signed binary op state + match self.signed_binary_op.take() { + Some(PendingSignedOverflowOp::AddOrSub { + bit_size: expected_bit_size, + result, + }) => { + assert_eq!(*bit_size, expected_bit_size); + assert_eq!(*max_bit_size, expected_bit_size + 1); + assert_eq!(*value, result); + } + Some(PendingSignedOverflowOp::Mul { + bit_size: expected_bit_size, + cast_result: Some(cast), + .. + }) => { + assert_eq!(*bit_size, expected_bit_size); + assert_eq!(*max_bit_size, 2 * expected_bit_size); + assert_eq!(*value, cast); + } + Some(PendingSignedOverflowOp::Mul { cast_result: None, .. }) => { + panic!("Truncate not matched to signed overflow pattern"); + } + None => { + // Do nothing as there is no overflow op pending + } + } + } + Instruction::Cast(value, typ) => { + match &mut self.signed_binary_op { + Some(PendingSignedOverflowOp::AddOrSub { .. }) => { + panic!( + "Invalid cast inserted after signed checked Add/Sub. It must be followed immediately by truncate" + ); + } + Some(PendingSignedOverflowOp::Mul { + bit_size: expected_bit_size, + mul_result, + cast_result, + }) => { + assert_eq!(typ.bit_size(), 2 * *expected_bit_size); + assert_eq!(*value, *mul_result); + *cast_result = Some(dfg.instruction_results(instruction)[0]); + } + None => { + // Do nothing as there is no overflow op pending + } + } + } + _ => { + if self.signed_binary_op.is_some() { + panic!("Signed binary operation does not follow overflow pattern"); + } + } + } + } + + // Validates there is exactly one return block + fn validate_single_return_block(&self) { + let reachable_blocks = self.function.reachable_blocks(); + + let return_blocks: HashSet<_> = reachable_blocks + .iter() + .filter(|block| { + let terminator = self.function.dfg[**block].terminator().unwrap_or_else(|| { + panic!("Function {} has no terminator in block {block}", self.function.id()) + }); + matches!(terminator, TerminatorInstruction::Return { .. }) + }) + .collect(); + + if return_blocks.len() > 1 { + panic!("Function {} has multiple return blocks {return_blocks:?}", self.function.id()) + } + } + + /// Validates that the instruction has the expected types associated with the values in each instruction + fn type_check_instruction(&self, instruction: InstructionId) { + let dfg = &self.function.dfg; + match &dfg[instruction] { + Instruction::Binary(Binary { lhs, rhs, operator }) => { + let lhs_type = dfg.type_of_value(*lhs); + let rhs_type = dfg.type_of_value(*rhs); + match operator { + BinaryOp::Lt => { + if lhs_type != rhs_type { + panic!( + "Left-hand side and right-hand side of `lt` must have the same type" + ); + } + + if matches!(lhs_type, Type::Numeric(NumericType::NativeField)) { + panic!("Cannot use `lt` with field elements"); + } + } + BinaryOp::Shl => { + if !matches!(rhs_type, Type::Numeric(NumericType::Unsigned { bit_size: 8 })) + { + panic!("Right-hand side of `shl` must be u8"); + } + } + BinaryOp::Shr => { + if !matches!(rhs_type, Type::Numeric(NumericType::Unsigned { bit_size: 8 })) + { + panic!("Right-hand side of `shr` must be u8"); + } + } + _ => { + if lhs_type != rhs_type { + panic!( + "Left-hand side and right-hand side of `{}` must have the same type", + operator + ); + } + } + } + } + Instruction::ArrayGet { index, .. } | Instruction::ArraySet { index, .. } => { + let index_type = dfg.type_of_value(*index); + if !matches!(index_type, Type::Numeric(NumericType::Unsigned { bit_size: 32 })) { + panic!("ArrayGet/ArraySet index must be u32"); + } + } + Instruction::Call { func, arguments } => { + if let Value::Intrinsic(intrinsic) = &dfg[*func] { + match intrinsic { + Intrinsic::ToRadix(_) => { + assert_eq!(arguments.len(), 2); + + let value_typ = dfg.type_of_value(arguments[0]); + assert!(matches!(value_typ, Type::Numeric(NumericType::NativeField))); + + let radix_typ = dfg.type_of_value(arguments[1]); + assert!(matches!( + radix_typ, + Type::Numeric(NumericType::Unsigned { bit_size: 32 }) + )); + } + Intrinsic::ToBits(_) => { + // Intrinsic::ToBits always has a set radix + assert_eq!(arguments.len(), 1); + let value_typ = dfg.type_of_value(arguments[0]); + assert!(matches!(value_typ, Type::Numeric(NumericType::NativeField))); + } + _ => {} + } + } + } + _ => (), + } + } + + fn run(&mut self) { + self.validate_single_return_block(); + + for block in self.function.reachable_blocks() { + for instruction in self.function.dfg[block].instructions() { + self.validate_signed_op_overflow_pattern(*instruction); + self.type_check_instruction(*instruction); + } + } + + if self.signed_binary_op.is_some() { + panic!("Signed binary operation does not follow overflow pattern"); + } + } +} + +/// Validates that the [Function] is well formed. +/// +/// Panics on malformed functions. +pub(crate) fn validate_function(function: &Function) { + let mut validator = Validator::new(function); + validator.run(); +} + +#[cfg(test)] +mod tests { + use crate::ssa::ssa_gen::Ssa; + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_sub_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_sub_brillig() { + // This matches the test above we just want to make sure it holds in the Brillig runtime as well as ACIR + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_add_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_add_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_sub_bad_truncate_bit_size() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 32 bits, max_bit_size: 33 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_sub_bad_truncate_max_bit_size() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 18 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_sub_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_sub_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = sub v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_add_acir() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn truncate_follows_signed_add_brillig() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 17 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "Invalid cast inserted after signed checked Add/Sub. It must be followed immediately by truncate" + )] + fn cast_and_truncate_follows_signed_add() { + let src = r" + brillig(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = add v0, v1 + v3 = cast v2 as i32 + v4 = truncate v2 to 16 bits, max_bit_size: 17 + return v4 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn signed_mul_followed_by_binary() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: Field): + v1 = truncate v0 to 16 bits, max_bit_size: 254 + v2 = cast v1 as i16 + v3 = mul v2, v2 + v4 = div v3, v2 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn signed_mul_followed_by_cast_and_truncate() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_cast() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v0 as u16 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_cast_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u16 + v3 = truncate v2 to 16 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_truncate_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 32 bits, max_bit_size: 32 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "assertion `left == right` failed")] + fn signed_mul_followed_by_bad_truncate_max_bit_size() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: i16): + v1 = mul v0, v0 + v2 = cast v1 as u32 + v3 = truncate v2 to 16 bits, max_bit_size: 33 + v4 = cast v3 as i16 + return v4 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Signed binary operation does not follow overflow pattern")] + fn lone_signed_mul() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = mul v0, v1 + return v2 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Truncate not matched to signed overflow pattern")] + fn signed_mul_followed_by_truncate_but_no_cast() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16, v1: i16): + v2 = mul v0, v1 + v3 = truncate v2 to 16 bits, max_bit_size: 33 + return v3 + } + "; + + let _ = Ssa::from_str(src); + } + + #[test] + fn lone_truncate() { + let src = r" + acir(inline) pure fn main f0 { + b0(v0: i16): + v1 = truncate v0 to 8 bits, max_bit_size: 8 + return v1 + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Cannot use `lt` with field elements")] + fn disallows_comparing_fields_with_lt() { + let src = " + acir(inline) impure fn main f0 { + b0(): + v2 = lt Field 1, Field 2 + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "Left-hand side and right-hand side of `add` must have the same type" + )] + fn disallows_binary_add_with_different_types() { + let src = " + acir(inline) fn main f0 { + b0(): + v2 = add Field 1, i32 2 + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Right-hand side of `shr` must be u8")] + fn disallows_shr_with_non_u8() { + let src = " + acir(inline) fn main f0 { + b0(): + v2 = shr u32 1, u16 1 + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic(expected = "Right-hand side of `shl` must be u8")] + fn disallows_shl_with_non_u8() { + let src = " + acir(inline) fn main f0 { + b0(): + v2 = shl u32 1, u16 1 + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "assertion failed: matches!(value_typ, Type::Numeric(NumericType::NativeField))" + )] + fn to_le_radix_on_non_field_value() { + let src = " + brillig(inline) predicate_pure fn main f0 { + b0(): + call f1(u1 1) + return + } + brillig(inline) fn foo f1 { + b0(v0: u1): + v2 = call to_le_radix(v0, u32 256) -> [u7; 1] + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + #[should_panic( + expected = "assertion failed: matches!(value_typ, Type::Numeric(NumericType::NativeField))" + )] + fn to_le_bits_on_non_field_value() { + let src = " + brillig(inline) predicate_pure fn main f0 { + b0(): + call f1(u1 1) + return + } + brillig(inline) fn foo f1 { + b0(v0: u1): + v2 = call to_le_bits(v0) -> [u1; 32] + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn valid_to_le_radix() { + let src = " + brillig(inline) predicate_pure fn main f0 { + b0(v0: Field): + v1 = call to_le_bytes(v0, u32 256) -> [u8; 1] + return + } + "; + let _ = Ssa::from_str(src); + } + + #[test] + fn valid_to_le_bits() { + let src = " + brillig(inline) predicate_pure fn main f0 { + b0(v0: Field): + v1 = call to_le_bits(v0) -> [u1; 32] + return + } + "; + let _ = Ssa::from_str(src); + } +}