Skip to content

Commit

Permalink
support complex inputs and outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
katat committed Nov 5, 2024
1 parent e30e563 commit 34162a7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 23 deletions.
23 changes: 21 additions & 2 deletions examples/hint.no
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,26 @@ hint fn exp(const EXP: Field, val: Field) -> Field {
return res;
}

struct Thing {
xx: Field,
yy: Field,
}

hint fn multiple_inputs_outputs(aa: [Field; 2]) -> Thing {
return Thing {
xx: aa[0],
yy: aa[1],
};
}

fn main(pub public_input: Field, private_input: Field) -> Field {
// have to assert these inputs, otherwise it throws vars not in circuit error
assert_eq(public_input, 2);
assert_eq(private_input, 2);

let xx = unsafe add_mul_2(public_input, private_input);
let yy = unsafe mul(public_input, private_input);
assert_eq(xx, yy * 2); // builtin call
assert_eq(xx, yy * 2);

let zz = unsafe div(xx, public_input);
assert_eq(zz, yy);
Expand All @@ -43,5 +56,11 @@ fn main(pub public_input: Field, private_input: Field) -> Field {
log(kk);
assert_eq(kk, 16);

return xx;
let thing = unsafe multiple_inputs_outputs([public_input, 3]);
// have to include all the outputs from hint function, otherwise it throws vars not in circuit error.
// this is because each individual element in the hint output maps to a separate cell var in noname.
assert_eq(thing.xx, public_input);
assert_eq(thing.yy, 3);

return public_input;
}
61 changes: 44 additions & 17 deletions src/circuit_writer/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use circ::{
term,
};
use circ_fields::FieldT;
use kimchi::turshi::helper::CairoFieldHelpers;
use num_bigint::BigUint;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -394,7 +395,6 @@ impl<B: Backend> IRWriter<B> {
let var_info = VarInfo::new(rhs_var, *mutable, typ);

// store the new variable
// TODO: do we really need to store that in the scope? That's not an actual var in the scope that's an internal var...
self.add_local_var(fn_env, lhs.value.clone(), var_info)?;
}

Expand Down Expand Up @@ -534,7 +534,7 @@ impl<B: Backend> IRWriter<B> {
&mut self,
function: &FunctionDef,
args: Vec<crate::circuit_writer::fn_env::VarInfo<B::Field, B::Var>>,
) -> Result<Option<crate::var::Value<B>>> {
) -> Result<Vec<crate::var::Value<B>>> {
assert!(!function.is_main());

// create new fn_env
Expand All @@ -544,26 +544,52 @@ impl<B: Backend> IRWriter<B> {
assert_eq!(function.sig.arguments.len(), args.len());

// create circ var terms for the arguments
let cfg = CircCfg::default();
let cfg_f = cfg.field();
let mut named_args = vec![];
for (arg, observed) in function.sig.arguments.iter().zip(args) {
let name = &arg.name.value;
// create a circ field var term
let var_ir = leaf_term(Op::new_var(name.clone(), Sort::Field(FieldT::FBls12381)));
let var = Var::new_var(var_ir, Span::default());
// create a list of terms corresponding to the observed

let cvars = observed.var.cvars.iter().enumerate().map(|(i, v)| {
// internal var name for IR
let name = format!("{}_{}", name, i);
// map between circ IR variables and noname [ConstOrCell]
named_args.push((name.clone(), v.clone()));

match v {
crate::var::ConstOrCell::Const(cst) => {
let cst: u64 = cst.to_u64();
let cst_v = cfg_f.new_v(cst);
leaf_term(Op::new_const(Value::Field(cst_v)))
}
crate::var::ConstOrCell::Cell(_) => {
leaf_term(Op::new_var(name.clone(), Sort::Field(FieldT::FBls12381)))
}
}
});

let var = Var::new(cvars.collect(), observed.var.span);

// add as local var
let var_info = VarInfo::new(var, false, Some(TyKind::Field { constant: false }));
let var_info = VarInfo::new(var, false, Some(arg.typ.kind.clone()));
self.add_local_var(fn_env, name.clone(), var_info)?;
named_args.push((name.clone(), observed.var.cvars[0].clone()));
}

// compile it and potentially return a return value
self.compile_block(fn_env, &function.body).map(|r| {
r.map(|r| {
let t = r.cvars[0].clone();
crate::var::Value::HintIR(t, named_args)
})
})
let ir = self.compile_block(fn_env, &function.body)?;
if ir.is_none() {
return Ok(vec![]);
}

let res = ir.unwrap().cvars.into_iter().map(|v| {
// With the current setup to calculate symbolic values, the [compute_val] can only compute for one symbolic variable,
// thus it has to evaluate each symbolic variable separately from a hint function.
// Thus, this could introduce some performance overhead if the hint returns multiple symbolic variables.
crate::var::Value::HintIR(v, named_args.clone())
});

Ok(res.collect())
}

fn compile_native_function_call(
Expand Down Expand Up @@ -627,7 +653,7 @@ impl<B: Backend> IRWriter<B> {
let var = var.value(self, fn_env);

let typ = self.expr_type(arg).cloned();
let mutable = false; // TODO: mut keyword in arguments?
let mutable = false;
let var_info = VarInfo::new(var, mutable, typ);

vars.push(var_info);
Expand Down Expand Up @@ -738,7 +764,6 @@ impl<B: Backend> IRWriter<B> {
self.error(ErrorKind::NotAStaticMethod, method_name.span)
})?;

// TODO: for now we pass `self` by value as well
let mutable = false;
let self_var = self_var.value(self, fn_env);

Expand All @@ -755,7 +780,6 @@ impl<B: Backend> IRWriter<B> {
.compute_expr(fn_env, arg)?
.ok_or_else(|| self.error(ErrorKind::CannotComputeExpression, arg.span))?;

// TODO: for now we pass `self` by value as well
let mutable = false;
let var = var.value(self, fn_env);

Expand Down Expand Up @@ -862,7 +886,10 @@ impl<B: Backend> IRWriter<B> {
// Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span),
// Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span),
Op2::Division => {
let t: Term = term![Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(), term![Op::PfUnOp(PfUnOp::Recip); rhs.cvars[0].clone()]];
let t: Term = term![
Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(),
term![Op::PfUnOp(PfUnOp::Recip); rhs.cvars[0].clone()]
];
Var::new_cvar(t, expr.span)
}
_ => todo!(),
Expand Down
18 changes: 14 additions & 4 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,10 +438,20 @@ impl<B: Backend> CircuitWriter<B> {
self.ir_writer
.compile_hint_function_call(func, vars)
.map(|r| {
r.map(|r| {
let var = self.backend.new_internal_var(r, expr.span);
VarOrRef::Var(Var::new_var(var, expr.span))
})
let cvars: Vec<_> = r
.into_iter()
.map(|r| {
ConstOrCell::Cell(
self.backend.new_internal_var(r, expr.span),
)
})
.collect();

if cvars.is_empty() {
return None;
}

Some(VarOrRef::Var(Var::new(cvars, expr.span)))
})
} else {
self.compile_native_function_call(func, vars)
Expand Down

0 comments on commit 34162a7

Please sign in to comment.