diff --git a/CHANGELOG.md b/CHANGELOG.md index 84f10e6eba..40a8c26c75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,42 @@ #### Upcoming Changes +* Implement hints on field_arithmetic lib (Part 2) [#1004](https://github.com/lambdaclass/cairo-rs/pull/1004) + + `BuiltinHintProcessor` now supports the following hint: + + ```python + %{ + from starkware.python.math_utils import div_mod + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + b = pack(ids.b, num_bits_shift = 128) + p = pack(ids.p, num_bits_shift = 128) + # For python3.8 and above the modular inverse can be computed as follows: + # b_inverse_mod_p = pow(b, -1, p) + # Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils + b_inverse_mod_p = div_mod(1, b, p) + + + b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + + ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] + ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] + ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2] + %} + ``` + * Optimizations for hash builtin [#1029](https://github.com/lambdaclass/cairo-rs/pull/1029): * Track the verified addresses by offset in a `Vec` rather than storing the address in a `Vec` diff --git a/cairo_programs/field_arithmetic.cairo b/cairo_programs/field_arithmetic.cairo index 9d9c09393e..4386b44d96 100644 --- a/cairo_programs/field_arithmetic.cairo +++ b/cairo_programs/field_arithmetic.cairo @@ -124,9 +124,51 @@ namespace field_arithmetic { } } + // Computes a * b^{-1} modulo p + // NOTE: The modular inverse of b modulo p is computed in a hint and verified outside the hint with a multiplicaiton + func div{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) { + alloc_locals; + local b_inverse_mod_p: Uint384; + %{ + from starkware.python.math_utils import div_mod + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + b = pack(ids.b, num_bits_shift = 128) + p = pack(ids.p, num_bits_shift = 128) + # For python3.8 and above the modular inverse can be computed as follows: + # b_inverse_mod_p = pow(b, -1, p) + # Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils + b_inverse_mod_p = div_mod(1, b, p) + + + b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + + ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] + ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] + ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2] + %} + uint384_lib.check(b_inverse_mod_p); + let (b_times_b_inverse) = mul(b, b_inverse_mod_p, p); + assert b_times_b_inverse = Uint384(1, 0, 0); + + let (res: Uint384) = mul(a, b_inverse_mod_p, p); + return (res,); + } } func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() { + alloc_locals; // Test get_square //Small prime @@ -162,6 +204,25 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B assert r_c.d1 = 0; assert r_c.d2 = 0; + // Test div + // Small inputs + let a = Uint384(25, 0, 0); + let a_div = Uint384(5, 0, 0); + let a_p = Uint384(31, 0, 0); + let (a_r) = field_arithmetic.div(a, a_div, a_p); + assert a_r.d0 = 5; + assert a_r.d1 = 0; + assert a_r.d2 = 0; + + // Cairo Prime + let b = Uint384(1, 0, 5044639098474805171426); + let b_div = Uint384(1, 0, 2); + let b_p = Uint384(1, 0, 604462909807314605178880); + let (b_r) = field_arithmetic.div(b, b_div, b_p); + assert b_r.d0 = 280171807489444591652763463227596156607; + assert b_r.d1 = 122028556426724038784654414222572127555; + assert b_r.d2 = 410614585309032623322981; + return (); } diff --git a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs index 3a08ee93bd..bed25eccb3 100644 --- a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs +++ b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs @@ -84,6 +84,7 @@ use felt::Felt252; #[cfg(feature = "skip_next_instruction_hint")] use crate::hint_processor::builtin_hint_processor::skip_next_instruction::skip_next_instruction; +use super::field_arithmetic::uint384_div; use super::vrf::inv_mod_p_uint512::inv_mod_p_uint512; pub struct HintProcessorData { @@ -562,6 +563,7 @@ impl HintProcessor for BuiltinHintProcessor { hint_code::UINT384_SIGNED_NN => { uint384_signed_nn(vm, &hint_data.ids_data, &hint_data.ap_tracking) } + hint_code::UINT384_DIV => uint384_div(vm, &hint_data.ids_data, &hint_data.ap_tracking), hint_code::UINT256_MUL_DIV_MOD => { uint256_mul_div_mod(vm, &hint_data.ids_data, &hint_data.ap_tracking) } diff --git a/src/hint_processor/builtin_hint_processor/field_arithmetic.rs b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs index ee53e5f7c8..9b72a9788e 100644 --- a/src/hint_processor/builtin_hint_processor/field_arithmetic.rs +++ b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs @@ -1,10 +1,12 @@ use felt::Felt252; -use num_bigint::BigUint; +use num_bigint::{BigUint, ToBigInt}; +use num_integer::Integer; use num_traits::Zero; -use crate::math_utils::{is_quad_residue, sqrt_prime_power}; +use crate::math_utils::{is_quad_residue, mul_inv, sqrt_prime_power}; use crate::serde::deserialize_program::ApTracking; use crate::stdlib::{collections::HashMap, prelude::*}; +use crate::types::errors::math_errors::MathError; use crate::vm::errors::hint_errors::HintError; use crate::{ hint_processor::hint_processor_definition::HintReference, vm::vm_core::VirtualMachine, @@ -112,6 +114,68 @@ pub fn get_square_root( Ok(()) } + +/* Implements Hint: + %{ + from starkware.python.math_utils import div_mod + + def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + a = pack(ids.a, num_bits_shift = 128) + b = pack(ids.b, num_bits_shift = 128) + p = pack(ids.p, num_bits_shift = 128) + # For python3.8 and above the modular inverse can be computed as follows: + # b_inverse_mod_p = pow(b, -1, p) + # Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils + b_inverse_mod_p = div_mod(1, b, p) + + + b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + + ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] + ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] + ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2] +%} + */ +pub fn uint384_div( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + // Note: ids.a is not used here, nor is it used by following hints, so we dont need to extract it. + let b = pack(BigInt3::from_var_name("b", vm, ids_data, ap_tracking)?, 128) + .to_bigint() + .unwrap_or_default(); + let p = pack(BigInt3::from_var_name("p", vm, ids_data, ap_tracking)?, 128) + .to_bigint() + .unwrap_or_default(); + let b_inverse_mod_p_addr = + get_relocatable_from_var_name("b_inverse_mod_p", vm, ids_data, ap_tracking)?; + if b.is_zero() { + return Err(MathError::DividedByZero.into()); + } + let b_inverse_mod_p = mul_inv(&b, &p) + .mod_floor(&p) + .to_biguint() + .unwrap_or_default(); + let b_inverse_mod_p_split = split::<3>(&b_inverse_mod_p, 128); + for (i, b_inverse_mod_p_split) in b_inverse_mod_p_split.iter().enumerate() { + vm.insert_value( + (b_inverse_mod_p_addr + i)?, + Felt252::from(b_inverse_mod_p_split), + )?; + } + Ok(()) +} #[cfg(test)] mod tests { use super::*; @@ -268,4 +332,104 @@ mod tests { ((1, 15), 0) ]; } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_uint384_div_ok() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 11; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), 25), + ((1, 1), 0), + ((1, 2), 0), + //b + ((1, 3), 5), + ((1, 4), 0), + ((1, 5), 0), + //p + ((1, 6), 31), + ((1, 7), 0), + ((1, 8), 0) + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_DIV), Ok(())); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // b_inverse_mod_p + ((1, 9), 25), + ((1, 10), 0), + ((1, 11), 0) + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_uint384_div_b_is_zero() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 11; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), 25), + ((1, 1), 0), + ((1, 2), 0), + //b + ((1, 3), 0), + ((1, 4), 0), + ((1, 5), 0), + //p + ((1, 6), 31), + ((1, 7), 0), + ((1, 8), 0) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_DIV), + Err(HintError::Math(MathError::DividedByZero)) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_uint384_div_inconsistent_memory() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 11; + //Create hint_data + let ids_data = + non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)]; + //Insert ids into memory + vm.segments = segments![ + //a + ((1, 0), 25), + ((1, 1), 0), + ((1, 2), 0), + //b + ((1, 3), 5), + ((1, 4), 0), + ((1, 5), 0), + //p + ((1, 6), 31), + ((1, 7), 0), + ((1, 8), 0), + //b_inverse_mod_p + ((1, 9), 0) + ]; + //Execute the hint + assert_matches!( + run_hint!(vm, ids_data, hint_code::UINT384_DIV), + Err(HintError::Memory(MemoryError::InconsistentMemory(_, _, _))) + ); + } } diff --git a/src/hint_processor/builtin_hint_processor/hint_code.rs b/src/hint_processor/builtin_hint_processor/hint_code.rs index 81f4c3d946..2075633eb7 100644 --- a/src/hint_processor/builtin_hint_processor/hint_code.rs +++ b/src/hint_processor/builtin_hint_processor/hint_code.rs @@ -749,7 +749,7 @@ s = pack(ids.s, PRIME) % N value = res = div_mod(x, s, N)"; pub(crate) const XS_SAFE_DIV: &str = "value = k = safe_div(res * s - x, N)"; -// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib/uint384.cairo +// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib pub const UINT384_UNSIGNED_DIV_REM: &str = "def split(num: int, num_bits_shift: int, length: int): a = [] for _ in range(length): @@ -914,6 +914,33 @@ ids.sqrt_x.d2 = split_root_x[2] ids.sqrt_gx.d0 = split_root_gx[0] ids.sqrt_gx.d1 = split_root_gx[1] ids.sqrt_gx.d2 = split_root_gx[2]"; +pub const UINT384_DIV: &str = "from starkware.python.math_utils import div_mod + +def split(num: int, num_bits_shift: int, length: int): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + +def pack(z, num_bits_shift: int) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + +a = pack(ids.a, num_bits_shift = 128) +b = pack(ids.b, num_bits_shift = 128) +p = pack(ids.p, num_bits_shift = 128) +# For python3.8 and above the modular inverse can be computed as follows: +# b_inverse_mod_p = pow(b, -1, p) +# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils +b_inverse_mod_p = div_mod(1, b, p) + + +b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3) + +ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0] +ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1] +ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]"; pub const HI_MAX_BITLEN: &str = "ids.len_hi = max(ids.scalar_u.d2.bit_length(), ids.scalar_v.d2.bit_length())-1"; diff --git a/src/math_utils.rs b/src/math_utils.rs index 8ca9166102..ae8ba6668c 100644 --- a/src/math_utils.rs +++ b/src/math_utils.rs @@ -83,7 +83,7 @@ pub fn safe_div_usize(x: usize, y: usize) -> Result { } ///Returns num_a^-1 mod p -fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt { +pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt { if num_a.is_zero() { return BigInt::zero(); } diff --git a/src/tests/cairo_run_test.rs b/src/tests/cairo_run_test.rs index c34a934e5c..748799f1e8 100644 --- a/src/tests/cairo_run_test.rs +++ b/src/tests/cairo_run_test.rs @@ -712,7 +712,7 @@ fn uint384_extension() { #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn field_arithmetic() { let program_data = include_bytes!("../../cairo_programs/field_arithmetic.json"); - run_program_simple_with_memory_holes(program_data.as_slice(), 192); + run_program_simple_with_memory_holes(program_data.as_slice(), 272); } #[test]