Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 248 additions & 48 deletions src/libfuncs/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ fn build_mul<'ctx, 'this>(
let lhs_value = entry.arg(0)?;
let rhs_value = entry.arg(1)?;

// Extract the ranges for the operands and the result type.
// Extract the ranges for the operands.
let lhs_ty = registry.get_type(&info.signature.param_signatures[0].ty)?;
let rhs_ty = registry.get_type(&info.signature.param_signatures[1].ty)?;

Expand All @@ -351,41 +351,57 @@ fn build_mul<'ctx, 'this>(
.integer_range(registry)?;

let lhs_width = if lhs_ty.is_bounded_int(registry)? {
// P
lhs_range.offset_bit_width()
} else {
lhs_range.zero_based_bit_width()
};
let rhs_width = if rhs_ty.is_bounded_int(registry)? {
// Q
rhs_range.offset_bit_width()
} else {
rhs_range.zero_based_bit_width()
};
let lhs_lower_width = if lhs_range.lower.sign() != Sign::Minus {
// R
lhs_range.lower.bits()
} else {
lhs_range.lower.bits() + 1 // TODO: Check if this is correct
};
let rhs_lower_width = if rhs_range.lower.sign() != Sign::Minus {
// S
rhs_range.lower.bits()
} else {
rhs_range.lower.bits() + 1 // TODO: Check if this is correct
};

// Calculate the computation range.
let compute_range = Range {
lower: (&lhs_range.lower)
.min(&rhs_range.lower)
.min(&dst_range.lower)
.min(&BigInt::ZERO)
.clone(),
upper: (&lhs_range.upper)
.max(&rhs_range.upper)
.max(&dst_range.upper)
.clone(),
let compile_time_val =
lhs_range.lower.clone() * rhs_range.lower.clone() - dst_range.lower.clone();
let w = if compile_time_val.sign() != Sign::Minus {
// W
compile_time_val.bits()
} else {
compile_time_val.bits() + 1 // TODO: Check if this is correct
};
let compute_ty = IntegerType::new(context, compute_range.zero_based_bit_width()).into();

let x = (lhs_width.max(rhs_width) * 2) as u64;
let y = lhs_lower_width.max(rhs_lower_width) * 2;
let z = (rhs_width as u64).max(lhs_lower_width) * 2;

let compute_width = (x.max(y).max(z).max(w) + 3) as u32;
let compute_ty = IntegerType::new(context, compute_width).into();

// Zero-extend operands into the computation range.
native_assert!(
compute_range.offset_bit_width() >= lhs_width,
compute_width >= lhs_width,
"the lhs_range bit_width must be less or equal than the compute_range"
);
native_assert!(
compute_range.offset_bit_width() >= rhs_width,
compute_width >= rhs_width,
"the rhs_range bit_width must be less or equal than the compute_range"
);

let lhs_value = if compute_range.zero_based_bit_width() > lhs_width {
let lhs_value = if compute_width > lhs_width {
if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? {
entry.extui(lhs_value, compute_ty, location)?
} else {
Expand All @@ -394,7 +410,7 @@ fn build_mul<'ctx, 'this>(
} else {
lhs_value
};
let rhs_value = if compute_range.zero_based_bit_width() > rhs_width {
let rhs_value = if compute_width > rhs_width {
if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? {
entry.extui(rhs_value, compute_ty, location)?
} else {
Expand All @@ -404,43 +420,32 @@ fn build_mul<'ctx, 'this>(
rhs_value
};

// Offset the operands so that they are compatible with the operation.
let lhs_value = if lhs_ty.is_bounded_int(registry)? && lhs_range.lower != BigInt::ZERO {
let lhs_offset =
entry.const_int_from_type(context, location, lhs_range.lower, compute_ty)?;
entry.addi(lhs_value, lhs_offset, location)?
} else {
lhs_value
};
let rhs_value = if rhs_ty.is_bounded_int(registry)? && rhs_range.lower != BigInt::ZERO {
let rhs_offset =
entry.const_int_from_type(context, location, rhs_range.lower, compute_ty)?;
entry.addi(rhs_value, rhs_offset, location)?
} else {
rhs_value
};
let ao = entry.const_int_from_type(context, location, lhs_range.lower, compute_ty)?;
let bo = entry.const_int_from_type(context, location, rhs_range.lower, compute_ty)?;

// Compute the operation.
let res_value = entry.muli(lhs_value, rhs_value, location)?;
let ad_bd = entry.muli(lhs_value, rhs_value, location)?;
let ad_bo = entry.muli(lhs_value, bo, location)?;
let bd_ao = entry.muli(rhs_value, ao, location)?;
let compile_time_val =
entry.const_int_from_type(context, location, compile_time_val, compute_ty)?;

// Offset and truncate the result to the output type.
let res_offset = (&dst_range.lower).max(&compute_range.lower).clone();
let res_value = if res_offset != BigInt::ZERO {
let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?;
entry.append_op_result(arith::subi(res_value, res_offset, location))?
} else {
res_value
};
let res_value = entry.addi(ad_bd, ad_bo, location)?;
let res_value = entry.addi(res_value, bd_ao, location)?;
let mut res_value = entry.addi(res_value, compile_time_val, location)?;

let res_value = if dst_range.offset_bit_width() < compute_range.zero_based_bit_width() {
entry.trunci(
if compute_width > dst_range.offset_bit_width() {
res_value = entry.trunci(
res_value,
IntegerType::new(context, dst_range.offset_bit_width()).into(),
location,
)?
} else {
res_value
};
} else if compute_width < dst_range.offset_bit_width() {
res_value = entry.extui(
res_value,
IntegerType::new(context, dst_range.offset_bit_width()).into(),
location,
)?
}

helper.br(entry, 0, &[res_value], location)
}
Expand Down Expand Up @@ -851,13 +856,208 @@ fn build_wrap_non_zero<'ctx, 'this>(

#[cfg(test)]
mod test {
use cairo_lang_sierra::extensions::utils::Range;
use cairo_vm::Felt252;
use num_bigint::BigInt;

use crate::{
context::NativeContext, execution_result::ExecutionResult, executor::JitNativeExecutor,
load_cairo, OptLevel, Value,
jit_enum, jit_struct, load_cairo, utils::testing::run_program_assert_output, OptLevel,
Value,
};

#[test]
fn test_mul() {
let cairo = load_cairo!(
#[feature("bounded-int-utils")]
use core::internal::bounded_int::{self, BoundedInt, MulHelper, mul};

impl MulHelper1 of MulHelper<BoundedInt<-128, 127>, BoundedInt<-128, 127>> {
type Result = BoundedInt<-16256, 16384>;
}

impl MulHelper2 of MulHelper<BoundedInt<0, 128>, BoundedInt<0, 128>> {
type Result = BoundedInt<0, 16384>;
}

impl MulHelper3 of MulHelper<BoundedInt<1, 31>, BoundedInt<1, 1>> {
type Result = BoundedInt<1, 31>;
}

impl MulHelper4 of MulHelper<BoundedInt<-1, 31>, BoundedInt<-1, -1>> {
type Result = BoundedInt<-31, 1>;
}

impl MulHelper5 of MulHelper<BoundedInt<31, 31>, BoundedInt<1, 1>> {
type Result = BoundedInt<31, 31>;
}

impl MulHelper6 of MulHelper<BoundedInt<-100, 0>, BoundedInt<0, 100>> {
type Result = BoundedInt<-10000, 0>;
}

impl MulHelper7 of MulHelper<BoundedInt<1, 1>, BoundedInt<1, 1>> {
type Result = BoundedInt<1, 1>;
}

fn run_test_1(a: felt252, b: felt252) -> BoundedInt<-16256, 16384> {
let a: BoundedInt<-128, 127> = a.try_into().unwrap();
let b: BoundedInt<-128, 127> = b.try_into().unwrap();

mul(a,b)
}

fn run_test_2(a: felt252, b: felt252) -> BoundedInt<0, 16384> {
let a: BoundedInt<0, 128> = a.try_into().unwrap();
let b: BoundedInt<0, 128> = b.try_into().unwrap();

mul(a,b)
}

fn run_test_3(a: felt252, b: felt252) -> BoundedInt<1, 31> {
let a: BoundedInt<1, 31> = a.try_into().unwrap();
let b: BoundedInt<1, 1> = b.try_into().unwrap();

mul(a,b)
}

fn run_test_4(a: felt252, b: felt252) -> BoundedInt<-31, 1> {
let a: BoundedInt<-1, 31> = a.try_into().unwrap();
let b: BoundedInt<-1, -1> = b.try_into().unwrap();

mul(a,b)
}

fn run_test_5(a: felt252, b: felt252) -> BoundedInt<31, 31> {
let a: BoundedInt<31, 31> = a.try_into().unwrap();
let b: BoundedInt<1, 1> = b.try_into().unwrap();

mul(a,b)
}

fn run_test_6(a: felt252, b: felt252) -> BoundedInt<-10000,0> {
let a: BoundedInt<-100, 0> = a.try_into().unwrap();
let b: BoundedInt<0, 100> = b.try_into().unwrap();

mul(a,b)
}
);

run_program_assert_output(
&cairo,
"run_test_1",
&[
Value::Felt252(Felt252::from(-128)),
Value::Felt252(Felt252::from(-128)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(16384),
range: Range {
lower: BigInt::from(-16256),
upper: BigInt::from(16385),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_2",
&[
Value::Felt252(Felt252::from(126)),
Value::Felt252(Felt252::from(128)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(16128),
range: Range {
lower: BigInt::from(0),
upper: BigInt::from(16385),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_3",
&[
Value::Felt252(Felt252::from(31)),
Value::Felt252(Felt252::from(1)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(31),
range: Range {
lower: BigInt::from(1),
upper: BigInt::from(32),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_4",
&[
Value::Felt252(Felt252::from(31)),
Value::Felt252(Felt252::from(-1)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(-31),
range: Range {
lower: BigInt::from(-31),
upper: BigInt::from(2),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_5",
&[
Value::Felt252(Felt252::from(31)),
Value::Felt252(Felt252::from(1)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(31),
range: Range {
lower: BigInt::from(31),
upper: BigInt::from(32),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_6",
&[
Value::Felt252(Felt252::from(-100)),
Value::Felt252(Felt252::from(100)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(-10000),
range: Range {
lower: BigInt::from(-10000),
upper: BigInt::from(1),
}
})
),
);
}

#[test]
fn test_trim_some_pos_i8() {
let (_, program) = load_cairo!(
Expand Down