From 34162a764b2eaac14520fd0a2176d570b7a1465f Mon Sep 17 00:00:00 2001 From: kata Date: Tue, 5 Nov 2024 16:35:57 +0800 Subject: [PATCH] support complex inputs and outputs --- examples/hint.no | 23 ++++++++++++-- src/circuit_writer/ir.rs | 61 ++++++++++++++++++++++++++---------- src/circuit_writer/writer.rs | 18 ++++++++--- 3 files changed, 79 insertions(+), 23 deletions(-) diff --git a/examples/hint.no b/examples/hint.no index c9c3971f2..12bdf953f 100644 --- a/examples/hint.no +++ b/examples/hint.no @@ -25,13 +25,26 @@ hint fn exp(const EXP: Field, val: Field) -> Field { return res; } +struct Thing { + xx: Field, + yy: Field, +} + +hint fn multiple_inputs_outputs(aa: [Field; 2]) -> Thing { + return Thing { + xx: aa[0], + yy: aa[1], + }; +} + fn main(pub public_input: Field, private_input: Field) -> Field { + // have to assert these inputs, otherwise it throws vars not in circuit error assert_eq(public_input, 2); assert_eq(private_input, 2); let xx = unsafe add_mul_2(public_input, private_input); let yy = unsafe mul(public_input, private_input); - assert_eq(xx, yy * 2); // builtin call + assert_eq(xx, yy * 2); let zz = unsafe div(xx, public_input); assert_eq(zz, yy); @@ -43,5 +56,11 @@ fn main(pub public_input: Field, private_input: Field) -> Field { log(kk); assert_eq(kk, 16); - return xx; + let thing = unsafe multiple_inputs_outputs([public_input, 3]); + // have to include all the outputs from hint function, otherwise it throws vars not in circuit error. + // this is because each individual element in the hint output maps to a separate cell var in noname. + assert_eq(thing.xx, public_input); + assert_eq(thing.yy, 3); + + return public_input; } \ No newline at end of file diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index ab6f9c6bc..88e1ea985 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -4,6 +4,7 @@ use circ::{ term, }; use circ_fields::FieldT; +use kimchi::turshi::helper::CairoFieldHelpers; use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -394,7 +395,6 @@ impl IRWriter { let var_info = VarInfo::new(rhs_var, *mutable, typ); // store the new variable - // TODO: do we really need to store that in the scope? That's not an actual var in the scope that's an internal var... self.add_local_var(fn_env, lhs.value.clone(), var_info)?; } @@ -534,7 +534,7 @@ impl IRWriter { &mut self, function: &FunctionDef, args: Vec>, - ) -> Result>> { + ) -> Result>> { assert!(!function.is_main()); // create new fn_env @@ -544,26 +544,52 @@ impl IRWriter { assert_eq!(function.sig.arguments.len(), args.len()); // create circ var terms for the arguments + let cfg = CircCfg::default(); + let cfg_f = cfg.field(); let mut named_args = vec![]; for (arg, observed) in function.sig.arguments.iter().zip(args) { let name = &arg.name.value; - // create a circ field var term - let var_ir = leaf_term(Op::new_var(name.clone(), Sort::Field(FieldT::FBls12381))); - let var = Var::new_var(var_ir, Span::default()); + // create a list of terms corresponding to the observed + + let cvars = observed.var.cvars.iter().enumerate().map(|(i, v)| { + // internal var name for IR + let name = format!("{}_{}", name, i); + // map between circ IR variables and noname [ConstOrCell] + named_args.push((name.clone(), v.clone())); + + match v { + crate::var::ConstOrCell::Const(cst) => { + let cst: u64 = cst.to_u64(); + let cst_v = cfg_f.new_v(cst); + leaf_term(Op::new_const(Value::Field(cst_v))) + } + crate::var::ConstOrCell::Cell(_) => { + leaf_term(Op::new_var(name.clone(), Sort::Field(FieldT::FBls12381))) + } + } + }); + + let var = Var::new(cvars.collect(), observed.var.span); // add as local var - let var_info = VarInfo::new(var, false, Some(TyKind::Field { constant: false })); + let var_info = VarInfo::new(var, false, Some(arg.typ.kind.clone())); self.add_local_var(fn_env, name.clone(), var_info)?; - named_args.push((name.clone(), observed.var.cvars[0].clone())); } // compile it and potentially return a return value - self.compile_block(fn_env, &function.body).map(|r| { - r.map(|r| { - let t = r.cvars[0].clone(); - crate::var::Value::HintIR(t, named_args) - }) - }) + let ir = self.compile_block(fn_env, &function.body)?; + if ir.is_none() { + return Ok(vec![]); + } + + let res = ir.unwrap().cvars.into_iter().map(|v| { + // With the current setup to calculate symbolic values, the [compute_val] can only compute for one symbolic variable, + // thus it has to evaluate each symbolic variable separately from a hint function. + // Thus, this could introduce some performance overhead if the hint returns multiple symbolic variables. + crate::var::Value::HintIR(v, named_args.clone()) + }); + + Ok(res.collect()) } fn compile_native_function_call( @@ -627,7 +653,7 @@ impl IRWriter { let var = var.value(self, fn_env); let typ = self.expr_type(arg).cloned(); - let mutable = false; // TODO: mut keyword in arguments? + let mutable = false; let var_info = VarInfo::new(var, mutable, typ); vars.push(var_info); @@ -738,7 +764,6 @@ impl IRWriter { self.error(ErrorKind::NotAStaticMethod, method_name.span) })?; - // TODO: for now we pass `self` by value as well let mutable = false; let self_var = self_var.value(self, fn_env); @@ -755,7 +780,6 @@ impl IRWriter { .compute_expr(fn_env, arg)? .ok_or_else(|| self.error(ErrorKind::CannotComputeExpression, arg.span))?; - // TODO: for now we pass `self` by value as well let mutable = false; let var = var.value(self, fn_env); @@ -862,7 +886,10 @@ impl IRWriter { // Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), // Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), Op2::Division => { - let t: Term = term![Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(), term![Op::PfUnOp(PfUnOp::Recip); rhs.cvars[0].clone()]]; + let t: Term = term![ + Op::PfNaryOp(PfNaryOp::Mul); lhs.cvars[0].clone(), + term![Op::PfUnOp(PfUnOp::Recip); rhs.cvars[0].clone()] + ]; Var::new_cvar(t, expr.span) } _ => todo!(), diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 076b4684e..a2a56ca9b 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -438,10 +438,20 @@ impl CircuitWriter { self.ir_writer .compile_hint_function_call(func, vars) .map(|r| { - r.map(|r| { - let var = self.backend.new_internal_var(r, expr.span); - VarOrRef::Var(Var::new_var(var, expr.span)) - }) + let cvars: Vec<_> = r + .into_iter() + .map(|r| { + ConstOrCell::Cell( + self.backend.new_internal_var(r, expr.span), + ) + }) + .collect(); + + if cvars.is_empty() { + return None; + } + + Some(VarOrRef::Var(Var::new(cvars, expr.span))) }) } else { self.compile_native_function_call(func, vars)