diff --git a/acvm-repo/brillig/src/opcodes.rs b/acvm-repo/brillig/src/opcodes.rs index e7170c56087..70800d0aa6e 100644 --- a/acvm-repo/brillig/src/opcodes.rs +++ b/acvm-repo/brillig/src/opcodes.rs @@ -84,6 +84,23 @@ impl HeapValueType { pub fn field() -> HeapValueType { HeapValueType::Simple(BitSize::Field) } + + pub fn flattened_size(&self) -> Option { + match self { + HeapValueType::Simple(_) => Some(1), + HeapValueType::Array { value_types, size } => { + let element_size = + value_types.iter().map(|t| t.flattened_size()).sum::>(); + + // Multiply element size by number of elements. + element_size.map(|element_size| element_size * size) + } + HeapValueType::Vector { .. } => { + // Vectors are dynamic, so we cannot determine their size statically. + None + } + } + } } /// A fixed-sized array starting from a Brillig memory location. diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 5bdf6a55f78..32a7a8c7895 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -769,11 +769,33 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { )); } + debug_assert_eq!( + destinations.len(), + destination_value_types.len(), + "Number of destinations must match number of value types", + ); + debug_assert_eq!( + destinations.len(), + values.len(), + "Number of foreign call return values must match number of destinations", + ); for ((destination, value_type), output) in destinations.iter().zip(destination_value_types).zip(&values) { match (destination, value_type) { (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => { + let output_fields = output.fields(); + if value_type + .flattened_size() + .is_some_and(|flattened_size| output_fields.len() != flattened_size) + { + return Err(format!( + "Foreign call return value does not match expected size. Expected {} but got {}", + value_type.flattened_size().unwrap(), + output_fields.len(), + )); + } + match output { ForeignCallParam::Single(value) => { self.write_value_to_memory(*value_index, value, *bit_size)?; @@ -789,20 +811,38 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), HeapValueType::Array { value_types, size: type_size }, ) if size == type_size => { + let output_fields = output.fields(); + if value_type + .flattened_size() + .is_some_and(|flattened_size| output_fields.len() != flattened_size) + { + return Err(format!( + "Foreign call return value does not match expected size. Expected {} but got {}", + value_type.flattened_size().unwrap(), + output_fields.len(), + )); + } + if HeapValueType::all_simple(value_types) { match output { ForeignCallParam::Array(values) => { if values.len() != *size { // foreign call returning flattened values into a nested type, so the sizes do not match let destination = self.memory.read_ref(*pointer_index); - let return_type = value_type; + let mut flatten_values_idx = 0; //index of values read from flatten_values self.write_slice_of_values_to_memory( destination, - &output.fields(), + &output_fields, &mut flatten_values_idx, - return_type, + value_type, )?; + // Should be caught earlier but we want to be explicit. + debug_assert_eq!( + flatten_values_idx, + output_fields.len(), + "Not all values were written to memory" + ); } else { self.write_values_to_memory_slice( *pointer_index, @@ -825,10 +865,15 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { let mut flatten_values_idx = 0; //index of values read from flatten_values self.write_slice_of_values_to_memory( destination, - &output.fields(), + &output_fields, &mut flatten_values_idx, return_type, )?; + debug_assert_eq!( + flatten_values_idx, + output_fields.len(), + "Not all values were written to memory" + ); } } ( @@ -841,6 +886,10 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { if HeapValueType::all_simple(value_types) { match output { ForeignCallParam::Array(values) => { + if values.len() % value_types.len() != 0 { + return Err("Returned data does not match vector element size" + .to_string()); + } // Set our size in the size address self.memory.write(*size_index, values.len().into()); self.write_values_to_memory_slice( @@ -868,8 +917,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { } } - let _ = - std::mem::replace(&mut self.foreign_call_results[foreign_call_index].values, values); + self.foreign_call_results[foreign_call_index].values = values; Ok(()) } @@ -2630,4 +2678,171 @@ mod tests { } ); } + + #[test] + fn aborts_when_foreign_call_returns_too_much_data() { + let calldata: Vec = vec![]; + + let opcodes = &[ + Opcode::Const { + destination: MemoryAddress::direct(0), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(1u64), + }, + Opcode::ForeignCall { + function: "foo".to_string(), + destinations: vec![ValueOrArray::HeapArray(HeapArray { + pointer: MemoryAddress::Direct(0), + size: 3, + })], + destination_value_types: vec![HeapValueType::Array { + value_types: vec![HeapValueType::Simple(BitSize::Field)], + size: 3, + }], + inputs: Vec::new(), + input_value_types: Vec::new(), + }, + ]; + let solver = StubbedBlackBoxSolver::default(); + let mut vm = VM::new(calldata, opcodes, &solver, false, None); + + let status = vm.process_opcodes(); + assert_eq!( + status, + VMStatus::ForeignCallWait { function: "foo".to_string(), inputs: Vec::new() } + ); + vm.resolve_foreign_call(ForeignCallResult { + values: vec![ForeignCallParam::Array(vec![ + FieldElement::from(1u128), + FieldElement::from(2u128), + FieldElement::from(3u128), + FieldElement::from(4u128), // Extra value that exceeds the expected size + ])], + }); + + let status = vm.process_opcode(); + assert_eq!( + status, + VMStatus::Failure { + reason: FailureReason::RuntimeError { + message: + "Foreign call return value does not match expected size. Expected 3 but got 4" + .to_string() + }, + call_stack: vec![1] + } + ); + } + + #[test] + fn aborts_when_foreign_call_returns_not_enough_much_data() { + let calldata: Vec = vec![]; + + let opcodes = &[ + Opcode::Const { + destination: MemoryAddress::direct(0), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(1u64), + }, + Opcode::ForeignCall { + function: "foo".to_string(), + destinations: vec![ValueOrArray::HeapArray(HeapArray { + pointer: MemoryAddress::Direct(0), + size: 3, + })], + destination_value_types: vec![HeapValueType::Array { + value_types: vec![HeapValueType::Simple(BitSize::Field)], + size: 3, + }], + inputs: Vec::new(), + input_value_types: Vec::new(), + }, + ]; + let solver = StubbedBlackBoxSolver::default(); + let mut vm = VM::new(calldata, opcodes, &solver, false, None); + + let status = vm.process_opcodes(); + assert_eq!( + status, + VMStatus::ForeignCallWait { function: "foo".to_string(), inputs: Vec::new() } + ); + vm.resolve_foreign_call(ForeignCallResult { + values: vec![ForeignCallParam::Array(vec![ + FieldElement::from(1u128), + FieldElement::from(2u128), + // We're missing a value here + ])], + }); + + let status = vm.process_opcode(); + assert_eq!( + status, + VMStatus::Failure { + reason: FailureReason::RuntimeError { + message: + "Foreign call return value does not match expected size. Expected 3 but got 2" + .to_string() + }, + call_stack: vec![1] + } + ); + } + + #[test] + fn aborts_when_foreign_call_returns_data_which_doesnt_match_vector_elements() { + let calldata: Vec = vec![]; + + let opcodes = &[ + Opcode::Const { + destination: MemoryAddress::direct(0), + bit_size: BitSize::Integer(IntegerBitSize::U32), + value: FieldElement::from(2u64), + }, + Opcode::ForeignCall { + function: "foo".to_string(), + destinations: vec![ValueOrArray::HeapVector(HeapVector { + pointer: MemoryAddress::Direct(0), + size: MemoryAddress::Direct(1), + })], + destination_value_types: vec![HeapValueType::Vector { + value_types: vec![ + HeapValueType::Simple(BitSize::Field), + HeapValueType::Simple(BitSize::Field), + ], + }], + inputs: Vec::new(), + input_value_types: Vec::new(), + }, + ]; + let solver = StubbedBlackBoxSolver::default(); + let mut vm = VM::new(calldata, opcodes, &solver, false, None); + + let status = vm.process_opcodes(); + assert_eq!( + status, + VMStatus::ForeignCallWait { function: "foo".to_string(), inputs: Vec::new() } + ); + + // Here we're returning an array of 3 elements, however the vector expects 2 fields per element + // (see `value_types` above), so the returned data does not match the expected vector element size + vm.resolve_foreign_call(ForeignCallResult { + values: vec![ForeignCallParam::Array(vec![ + FieldElement::from(1u128), + FieldElement::from(2u128), + FieldElement::from(3u128), + // We're missing a value here + ])], + }); + + let status = vm.process_opcode(); + assert_eq!( + status, + VMStatus::Failure { + reason: FailureReason::RuntimeError { + message: "Returned data does not match vector element size".to_string() + }, + call_stack: vec![1] + } + ); + } } diff --git a/test_programs/noir_test_success/regression_4561/src/main.nr b/test_programs/noir_test_success/regression_4561/src/main.nr index d925ed01cd8..afc7cabfb9b 100644 --- a/test_programs/noir_test_success/regression_4561/src/main.nr +++ b/test_programs/noir_test_success/regression_4561/src/main.nr @@ -60,13 +60,20 @@ struct TestTypeFoo { } #[test] -unconstrained fn complexe_struct_return() { - OracleMock::mock("foo_return").returns(( - 0, [1, 2, 3, 4, 5, 6], 7, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6], +unconstrained fn complex_struct_return() { + let _ = OracleMock::mock("foo_return").returns(( + 0, [[1, 2, 3], [4, 5, 6]], + TestTypeFoo { + a: 7, + b: [ + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], + [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]], + ], + c: [[1, 2, 3], [4, 5, 6]], + d: [1, 2, 3], + }, )); - let foo_x = foo_return_unconstrained(); + let foo_x: (Field, [[Field; 3]; 2], TestTypeFoo) = foo_return_unconstrained(); assert_eq((foo_x.0, foo_x.1), (0, [[1, 2, 3], [4, 5, 6]])); assert_eq(foo_x.2.a, 7); assert_eq(