Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Builtin operators #197

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/comparator.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
fn main(pub xx: Field, yy: Field) {
let res = xx < yy;
assert(res);
}
3 changes: 3 additions & 0 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,9 @@ impl<B: Backend> CircuitWriter<B> {
Op2::Multiplication => field::mul(self, &lhs[0], &rhs[0], expr.span),
Op2::Equality => field::equal(self, &lhs, &rhs, expr.span),
Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span),
// todo: refactor the input vars from Var to VarInfo,
// which contain the type to provide the info about the bit length
Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span),
Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span),
Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span),
Op2::Division => todo!(),
Expand Down
115 changes: 115 additions & 0 deletions src/constraints/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,121 @@ pub fn not_equal<B: Backend>(
acc
}

/// Returns 1 if lhs < rhs, 0 otherwise
pub fn less_than<B: Backend>(
compiler: &mut CircuitWriter<B>,
bitlen: Option<usize>,
lhs: &ConstOrCell<B::Field, B::Var>,
rhs: &ConstOrCell<B::Field, B::Var>,
span: Span,
) -> Var<B::Field, B::Var> {
let one = B::Field::one();
let zero = B::Field::zero();

// Instead of comparing bit by bit, we check the carry bit:
// lhs + (1 << LEN) - rhs
// proof:
// lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs,
// resulted in a bit array of length LEN + 1, named as sum_bits.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand this but intuitively this is not going to work when we are too close to the modulus

// if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)`
// then, the carry bit of sum_bits is 0.
// otherwise, the carry bit of sum_bits is 1.

/*
psuedo code:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pseudo*

let carry_bit_len = LEN + 1;

# 1 << LEN
let mut pow2 = 1;
for ii in 0..LEN {
pow2 = pow2 + pow2;
}

let sum = (pow2 + lhs) - rhs;
let sum_bit = bits::to_bits(carry_bit_len, sum);

let b1 = false;
let b2 = true;
let res = if sum_bit[LEN] { b1 } else { b2 };

*/

let modulus_bits: usize = B::Field::modulus_biguint()
.bits()
.try_into()
.expect("can't determine the number of bits in the modulus");

let bitlen_upper_bound = modulus_bits - 2;
let bit_len = bitlen.unwrap_or(bitlen_upper_bound);

assert!(bit_len <= (bitlen_upper_bound));


let carry_bit_len = bit_len + 1;


// let pow2 = (1 << bit_len) as u32;
// let pow2 = B::Field::from(pow2);
let two = B::Field::from(2u32);
let pow2 = two.pow([bit_len as u64]);

// let pow2_lhs = compiler.backend.add_const(lhs, &pow2, span);
match (lhs, rhs) {
(ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => {
let res = if lhs < rhs { one } else { zero };

Var::new_constant(res, span)
}
(_, _) => {
let pow2_lhs = match lhs {
// todo: we really should refactor the backend to handle ConstOrCell
ConstOrCell::Const(lhs) => compiler.backend.add_constant(
Some("wrap a constant as var"),
*lhs + pow2,
span,
),
ConstOrCell::Cell(lhs) => compiler.backend.add_const(lhs, &pow2, span),
};

let rhs = match rhs {
ConstOrCell::Const(rhs) => compiler.backend.add_constant(
Some("wrap a constant as var"),
*rhs,
span,
),
ConstOrCell::Cell(rhs) => rhs.clone(),
};

let sum = compiler.backend.sub(&pow2_lhs, &rhs, span);

// todo: this api call is kind of weird here, maybe these bulitin shouldn't get inputs from the `GenericParameters`
let generic_var_name = "LEN".to_string();
let mut gens = GenericParameters::default();
gens.add(generic_var_name.clone());
gens.assign(&generic_var_name, carry_bit_len as u32, span)
.unwrap();

// construct var info for sum
let cbl_var = Var::new_constant(B::Field::from(carry_bit_len as u32), span);
let cbl_var = VarInfo::new(cbl_var, false, Some(TyKind::Field { constant: true }));

let sum_var = Var::new_var(sum, span);
let sum_var = VarInfo::new(sum_var, false, Some(TyKind::Field { constant: false }));

let sum_bits = to_bits(compiler, &gens, &[cbl_var, sum_var], span).unwrap().unwrap();
// convert to cell vars
let sum_bits: Vec<_> = sum_bits.cvars.into_iter().collect();

// if sum_bits[LEN] == 0, then lhs < rhs
let res = &is_zero_cell(compiler, &sum_bits[bit_len], span)[0];
let res = res
.cvar()
.unwrap();
Var::new_var(res.clone(), span)
}
}
}

/// Returns 1 if var is zero, 0 otherwise
fn is_zero_cell<B: Backend>(
compiler: &mut CircuitWriter<B>,
Expand Down
5 changes: 3 additions & 2 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,9 @@ fn monomorphize_expr<B: Backend>(
let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?;

let typ = match op {
Op2::Equality => Some(TyKind::Bool),
Op2::Inequality => Some(TyKind::Bool),
Op2::Equality
| Op2::Inequality
| Op2::LessThan => Some(TyKind::Bool),
Op2::Addition
| Op2::Subtraction
| Op2::Multiplication
Expand Down
3 changes: 3 additions & 0 deletions src/parser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub enum Op2 {
Division,
Equality,
Inequality,
LessThan,
BoolAnd,
BoolOr,
}
Expand Down Expand Up @@ -438,6 +439,7 @@ impl Expr {
| TokenKind::Slash
| TokenKind::DoubleEqual
| TokenKind::NotEqual
| TokenKind::Less
| TokenKind::DoubleAmpersand
| TokenKind::DoublePipe
| TokenKind::Exclamation,
Expand All @@ -452,6 +454,7 @@ impl Expr {
TokenKind::Slash => Op2::Division,
TokenKind::DoubleEqual => Op2::Equality,
TokenKind::NotEqual => Op2::Inequality,
TokenKind::Less => Op2::LessThan,
TokenKind::DoubleAmpersand => Op2::BoolAnd,
TokenKind::DoublePipe => Op2::BoolOr,
_ => unreachable!(),
Expand Down
12 changes: 12 additions & 0 deletions src/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,18 @@ fn test_not_equal(#[case] backend: BackendKind) -> miette::Result<()> {
Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_comparator(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"xx": "1"}"#;
let private_inputs = r#"{"yy": "2"}"#;

test_file("comparator", public_inputs, private_inputs, vec![], backend)?;

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
Expand Down
5 changes: 3 additions & 2 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,9 @@ impl<B: Backend> TypeChecker<B> {
}

let typ = match op {
Op2::Equality => TyKind::Bool,
Op2::Inequality => TyKind::Bool,
Op2::Equality
| Op2::Inequality
| Op2::LessThan => TyKind::Bool,
Op2::Addition
| Op2::Subtraction
| Op2::Multiplication
Expand Down