diff --git a/Cargo.toml b/Cargo.toml index 803beade..e5bbb0d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,14 +10,18 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" +hex = "0.4" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } # system_halo2 halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } # loader_evm -ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10.1", optional = true } +ethereum_types = { package = "ethereum-types", version = "0.13", default-features = false, features = ["std"], optional = true } +sha3 = { version = "0.10", optional = true } +revm = { version = "2.1.0", optional = true } +bytes = { version = "1.2", optional = true } +rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } @@ -31,14 +35,13 @@ paste = "1.0.7" halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm -foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "6b1ee60e" } -crossterm = { version = "0.22.1" } -tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] default = ["loader_evm", "loader_halo2", "system_halo2"] -loader_evm = ["dep:ethereum_types", "dep:sha3"] +loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 66fa8ddf..10a7b269 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -1,5 +1,4 @@ use ethereum_types::Address; -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ dev::MockProver, @@ -18,7 +17,7 @@ use halo2_proofs::{ use itertools::Itertools; use plonk_verifier::{ loader::{ - evm::{encode_calldata, EvmLoader}, + evm::{encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, @@ -167,7 +166,7 @@ mod aggregation { use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, ConstraintSystem}, + plonk::{self, Circuit, ConstraintSystem, Error}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; use halo2_wrong_ecc::{ @@ -182,7 +181,7 @@ mod aggregation { use plonk_verifier::{ loader::{self, native::NativeLoader}, pcs::{ - kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, + kzg::{KzgAccumulator, KzgSuccinctVerifyingKey, LimbsEncodingInstructions}, AccumulationScheme, AccumulationSchemeProver, }, system, @@ -191,7 +190,7 @@ mod aggregation { Protocol, }; use rand::rngs::OsRng; - use std::{iter, rc::Rc}; + use std::rc::Rc; const T: usize = 5; const RATE: usize = 4; @@ -434,28 +433,33 @@ mod aggregation { range_chip.load_table(&mut layouter)?; - let (lhs, rhs) = layouter.assign_region( + let accumulator_limbs = layouter.assign_region( || "", |region| { let ctx = RegionCtx::new(region, 0); let ecc_chip = config.ecc_chip(); let loader = Halo2Loader::new(ecc_chip, ctx); - let KzgAccumulator { lhs, rhs } = - aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); + let accumulator = aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); - Ok((lhs.assigned(), rhs.assigned())) + let accumulator_limbs = [accumulator.lhs, accumulator.rhs] + .iter() + .map(|ec_point| { + loader.ecc_chip().assign_ec_point_to_limbs( + &mut loader.ctx_mut(), + ec_point.assigned(), + ) + }) + .collect::, Error>>()? + .into_iter() + .flatten(); + + Ok(accumulator_limbs) }, )?; - for (limb, row) in iter::empty() - .chain(lhs.x().limbs()) - .chain(lhs.y().limbs()) - .chain(rhs.x().limbs()) - .chain(rhs.y().limbs()) - .zip(0..) - { - main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + for (row, limb) in accumulator_limbs.enumerate() { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; } Ok(()) @@ -574,16 +578,14 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) let success = { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); + .build(); let caller = Address::from_low_u64_be(0xfe); let verifier = evm - .deploy(caller, deployment_code.into(), 0.into(), None) - .unwrap() - .address; - let result = evm - .call_raw(caller, verifier, calldata.into(), 0.into()) + .deploy(caller, deployment_code.into(), 0.into()) + .address .unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); dbg!(result.gas_used); diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index 9ed70b36..4ff5c682 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -1,5 +1,4 @@ use ethereum_types::Address; -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, @@ -21,7 +20,7 @@ use halo2_proofs::{ }; use itertools::Itertools; use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader}, + loader::evm::{encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, @@ -231,16 +230,14 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) let success = { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); + .build(); let caller = Address::from_low_u64_be(0xfe); let verifier = evm - .deploy(caller, deployment_code.into(), 0.into(), None) - .unwrap() - .address; - let result = evm - .call_raw(caller, verifier, calldata.into(), 0.into()) + .deploy(caller, deployment_code.into(), 0.into()) + .address .unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); dbg!(result.gas_used); diff --git a/src/loader.rs b/src/loader.rs index 5a040f9f..297390d0 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -5,7 +5,7 @@ use crate::{ }, Error, }; -use std::{fmt::Debug, iter}; +use std::{borrow::Cow, fmt::Debug, iter, ops::Deref}; pub mod native; @@ -86,7 +86,7 @@ pub trait EcPointLoader { ) -> Result<(), Error>; fn multi_scalar_multiplication( - pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)], + pairs: &[(&Self::LoadedScalar, &Self::LoadedEcPoint)], ) -> Self::LoadedEcPoint where Self: ScalarLoader; @@ -126,17 +126,18 @@ pub trait ScalarLoader { .chain(if constant == F::zero() { None } else { - Some(loader.load_const(&constant)) + Some(Cow::Owned(loader.load_const(&constant))) }) .chain(values.iter().map(|&(coeff, value)| { if coeff == F::one() { - value.clone() + Cow::Borrowed(value) } else { - loader.load_const(&coeff) * value + Cow::Owned(loader.load_const(&coeff) * value) } })) - .reduce(|acc, term| acc + term) + .reduce(|acc, term| Cow::Owned(acc.into_owned() + term.deref())) .unwrap() + .into_owned() } fn sum_products_with_coeff_and_const( diff --git a/src/loader/evm.rs b/src/loader/evm.rs index 7a07670c..fa80b97e 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -6,7 +6,9 @@ mod util; mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; -pub use util::{encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, MemoryChunk}; +pub use util::{ + encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, MemoryChunk, +}; pub use ethereum_types::U256; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 7d1dbb94..9e8b8e2a 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -37,7 +37,7 @@ impl PartialEq for Value { impl Value { fn identifier(&self) -> String { - match &self { + match self { Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), Value::Negated(value) => format!("-({:?})", value), Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), @@ -222,13 +222,13 @@ impl EvmLoader { pub fn ec_point_from_limbs( self: &Rc, - x_limbs: [Scalar; LIMBS], - y_limbs: [Scalar; LIMBS], + x_limbs: [&Scalar; LIMBS], + y_limbs: [&Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { for (idx, limb) in limbs.into_iter().enumerate() { - self.push(&limb); + self.push(limb); // [..., success, acc] if idx > 0 { self.code @@ -769,10 +769,11 @@ where } fn multi_scalar_multiplication( - pairs: &[(>::LoadedScalar, EcPoint)], + pairs: &[(&>::LoadedScalar, &EcPoint)], ) -> EcPoint { pairs .iter() + .cloned() .map(|(scalar, ec_point)| match scalar.value { Value::Constant(constant) if U256::one() == constant => ec_point.clone(), _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index e204c1b8..f81c5f54 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -1,10 +1,9 @@ -use crate::{loader::evm::test::tui::Tui, util::Itertools}; -use foundry_evm::{ - executor::{backend::Backend, fork::MultiFork, ExecutorBuilder}, - revm::{AccountInfo, Bytecode}, - utils::h256_to_u256_be, - Address, +use crate::{ + loader::evm::{test::tui::Tui, util::ExecutorBuilder}, + util::Itertools, }; +use ethereum_types::{Address, U256}; +use revm::{AccountInfo, Bytecode}; use std::env::var_os; mod tui; @@ -29,23 +28,20 @@ pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .set_tracing(debug) .set_debugger(debug) - .build(Backend::new(MultiFork::new().0, None)); + .build(); - evm.backend_mut().insert_account_info( + evm.db_mut().insert_account_info( callee, AccountInfo::new(0.into(), 1, Bytecode::new_raw(code.into())), ); - let result = evm - .call_raw(caller, callee, calldata.into(), 0.into()) - .unwrap(); + let result = evm.call_raw(caller, callee, calldata.into(), 0.into()); let costs = result .logs .into_iter() - .map(|log| h256_to_u256_be(log.topics[0]).as_u64()) + .map(|log| U256::from_big_endian(log.topics[0].as_bytes()).as_u64()) .collect_vec(); if debug { diff --git a/src/loader/evm/test/tui.rs b/src/loader/evm/test/tui.rs index fcaef36c..c0c4d7f8 100644 --- a/src/loader/evm/test/tui.rs +++ b/src/loader/evm/test/tui.rs @@ -1,5 +1,6 @@ //! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/ui/src/lib.rs +use crate::loader::evm::util::executor::{CallKind, DebugStep}; use crossterm::{ event::{ self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode, KeyEvent, KeyModifiers, @@ -8,11 +9,8 @@ use crossterm::{ execute, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; -use foundry_evm::{ - debug::{DebugStep, Instruction}, - revm::opcode, - Address, CallKind, -}; +use ethereum_types::Address; +use revm::opcode; use std::{ cmp::{max, min}, io, @@ -90,7 +88,7 @@ impl Tui { self.terminal.clear().unwrap(); let mut draw_memory: DrawMemory = DrawMemory::default(); - let debug_call: Vec<(Address, Vec, CallKind)> = self.debug_arena.clone(); + let debug_call = &self.debug_arena; let mut opcode_list: Vec = debug_call[0] .1 .iter() @@ -207,7 +205,7 @@ impl Tui { } KeyCode::Char('s') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let remaining_ops = opcode_list[self.current_step..].to_vec().clone(); + let remaining_ops = &opcode_list[self.current_step..]; self.current_step += remaining_ops .iter() .enumerate() @@ -233,7 +231,7 @@ impl Tui { } KeyCode::Char('a') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let prev_ops = opcode_list[..self.current_step].to_vec().clone(); + let prev_ops = &opcode_list[..self.current_step]; self.current_step = prev_ops .iter() .enumerate() @@ -618,12 +616,7 @@ impl Tui { .borders(Borders::ALL); let min_len = usize::max(format!("{}", stack.len()).len(), 2); - let indices_affected = - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - stack_indices_affected(op) - } else { - vec![] - }; + let indices_affected = stack_indices_affected(debug_steps[current_step].instruction.0); let text: Vec = stack .iter() @@ -699,33 +692,29 @@ impl Tui { let mut word = None; let mut color = None; - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - let stack_len = debug_steps[current_step].stack.len(); - if stack_len > 0 { - let w = debug_steps[current_step].stack[stack_len - 1]; - match op { - opcode::MLOAD => { - word = Some(w.as_usize() / 32); - color = Some(Color::Cyan); - } - opcode::MSTORE => { - word = Some(w.as_usize() / 32); - color = Some(Color::Red); - } - _ => {} + let stack_len = debug_steps[current_step].stack.len(); + if stack_len > 0 { + let w = debug_steps[current_step].stack[stack_len - 1]; + match debug_steps[current_step].instruction.0 { + opcode::MLOAD => { + word = Some(w.as_usize() / 32); + color = Some(Color::Cyan); } + opcode::MSTORE => { + word = Some(w.as_usize() / 32); + color = Some(Color::Red); + } + _ => {} } } if current_step > 0 { let prev_step = current_step - 1; let stack_len = debug_steps[prev_step].stack.len(); - if let Instruction::OpCode(op) = debug_steps[prev_step].instruction { - if op == opcode::MSTORE { - let prev_top = debug_steps[prev_step].stack[stack_len - 1]; - word = Some(prev_top.as_usize() / 32); - color = Some(Color::Green); - } + if debug_steps[prev_step].instruction.0 == opcode::MSTORE { + let prev_top = debug_steps[prev_step].stack[stack_len - 1]; + word = Some(prev_top.as_usize() / 32); + color = Some(Color::Green); } } diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs index 0d9698bd..0a772513 100644 --- a/src/loader/evm/util.rs +++ b/src/loader/evm/util.rs @@ -5,6 +5,10 @@ use crate::{ use ethereum_types::U256; use std::iter; +pub(crate) mod executor; + +pub use executor::ExecutorBuilder; + pub struct MemoryChunk { ptr: usize, len: usize, diff --git a/src/loader/evm/util/executor.rs b/src/loader/evm/util/executor.rs new file mode 100644 index 00000000..ec9695e0 --- /dev/null +++ b/src/loader/evm/util/executor.rs @@ -0,0 +1,868 @@ +//! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/evm/src/executor/mod.rs + +use bytes::Bytes; +use ethereum_types::{Address, H256, U256, U64}; +use revm::{ + evm_inner, opcode, spec_opcode_gas, Account, BlockEnv, CallInputs, CallScheme, CreateInputs, + CreateScheme, Database, DatabaseCommit, EVMData, Env, ExecutionResult, Gas, GasInspector, + InMemoryDB, Inspector, Interpreter, Memory, OpCode, Return, TransactOut, TransactTo, TxEnv, +}; +use sha3::{Digest, Keccak256}; +use std::{cell::RefCell, collections::HashMap, fmt::Display, rc::Rc}; + +macro_rules! return_ok { + () => { + Return::Continue | Return::Stop | Return::Return | Return::SelfDestruct + }; +} + +fn keccak256(data: impl AsRef<[u8]>) -> [u8; 32] { + Keccak256::digest(data.as_ref()).into() +} + +fn get_contract_address(sender: impl Into
, nonce: impl Into) -> Address { + let mut stream = rlp::RlpStream::new(); + stream.begin_list(2); + stream.append(&sender.into()); + stream.append(&nonce.into()); + + let hash = keccak256(&stream.out()); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create2_address( + from: impl Into
, + salt: [u8; 32], + init_code: impl Into, +) -> Address { + get_create2_address_from_hash(from, salt, keccak256(init_code.into().as_ref()).to_vec()) +} + +fn get_create2_address_from_hash( + from: impl Into
, + salt: [u8; 32], + init_code_hash: impl Into, +) -> Address { + let bytes = [ + &[0xff], + from.into().as_bytes(), + salt.as_slice(), + init_code_hash.into().as_ref(), + ] + .concat(); + + let hash = keccak256(&bytes); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create_address(call: &CreateInputs, nonce: u64) -> Address { + match call.scheme { + CreateScheme::Create => get_contract_address(call.caller, nonce), + CreateScheme::Create2 { salt } => { + let mut buffer: [u8; 4 * 8] = [0; 4 * 8]; + salt.to_big_endian(&mut buffer); + get_create2_address(call.caller, buffer, call.init_code.clone()) + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Log { + pub address: Address, + pub topics: Vec, + pub data: Bytes, + pub block_hash: Option, + pub block_number: Option, + pub transaction_hash: Option, + pub transaction_index: Option, + pub log_index: Option, + pub transaction_log_index: Option, + pub log_type: Option, + pub removed: Option, +} + +#[derive(Clone, Debug, Default)] +struct LogCollector { + logs: Vec, +} + +impl Inspector for LogCollector { + fn log(&mut self, _: &mut EVMData<'_, DB>, address: &Address, topics: &[H256], data: &Bytes) { + self.logs.push(Log { + address: *address, + topics: topics.to_vec(), + data: data.clone(), + ..Default::default() + }); + } + + fn call( + &mut self, + _: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } +} + +#[derive(Clone, Debug, Copy)] +pub enum CallKind { + Call, + StaticCall, + CallCode, + DelegateCall, + Create, + Create2, +} + +impl Default for CallKind { + fn default() -> Self { + CallKind::Call + } +} + +impl From for CallKind { + fn from(scheme: CallScheme) -> Self { + match scheme { + CallScheme::Call => CallKind::Call, + CallScheme::StaticCall => CallKind::StaticCall, + CallScheme::CallCode => CallKind::CallCode, + CallScheme::DelegateCall => CallKind::DelegateCall, + } + } +} + +impl From for CallKind { + fn from(create: CreateScheme) -> Self { + match create { + CreateScheme::Create => CallKind::Create, + CreateScheme::Create2 { .. } => CallKind::Create2, + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugArena { + pub arena: Vec, +} + +impl DebugArena { + fn push_node(&mut self, mut new_node: DebugNode) -> usize { + fn recursively_push( + arena: &mut Vec, + entry: usize, + mut new_node: DebugNode, + ) -> usize { + match new_node.depth { + _ if arena[entry].depth == new_node.depth - 1 => { + let id = arena.len(); + new_node.location = arena[entry].children.len(); + new_node.parent = Some(entry); + arena[entry].children.push(id); + arena.push(new_node); + id + } + _ => { + let child = *arena[entry].children.last().unwrap(); + recursively_push(arena, child, new_node) + } + } + } + + if self.arena.is_empty() { + self.arena.push(new_node); + 0 + } else if new_node.depth == 0 { + let id = self.arena.len(); + new_node.location = self.arena[0].children.len(); + new_node.parent = Some(0); + self.arena[0].children.push(id); + self.arena.push(new_node); + id + } else { + recursively_push(&mut self.arena, 0, new_node) + } + } + + #[cfg(test)] + pub fn flatten(&self, entry: usize) -> Vec<(Address, Vec, CallKind)> { + let node = &self.arena[entry]; + + let mut flattened = vec![]; + if !node.steps.is_empty() { + flattened.push((node.address, node.steps.clone(), node.kind)); + } + flattened.extend(node.children.iter().flat_map(|child| self.flatten(*child))); + + flattened + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugNode { + pub parent: Option, + pub children: Vec, + pub location: usize, + pub address: Address, + pub kind: CallKind, + pub depth: usize, + pub steps: Vec, +} + +#[derive(Clone, Debug)] +pub struct DebugStep { + pub stack: Vec, + pub memory: Memory, + pub instruction: Instruction, + pub push_bytes: Option>, + pub pc: usize, + pub total_gas_used: u64, +} + +impl Default for DebugStep { + fn default() -> Self { + Self { + stack: vec![], + memory: Memory::new(), + instruction: Instruction(revm::opcode::INVALID), + push_bytes: None, + pc: 0, + total_gas_used: 0, + } + } +} + +impl DebugStep { + #[cfg(test)] + pub fn pretty_opcode(&self) -> String { + if let Some(push_bytes) = &self.push_bytes { + format!("{}(0x{})", self.instruction, hex::encode(push_bytes)) + } else { + self.instruction.to_string() + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Instruction(pub u8); + +impl From for Instruction { + fn from(op: u8) -> Instruction { + Instruction(op) + } +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + OpCode::try_from_u8(self.0).map_or_else( + || format!("UNDEFINED(0x{:02x})", self.0), + |opcode| opcode.as_str().to_string(), + ) + ) + } +} + +#[derive(Clone, Debug)] +struct Debugger { + arena: DebugArena, + head: usize, + context: Address, + gas_inspector: Rc>, +} + +impl Debugger { + fn new(gas_inspector: Rc>) -> Self { + Self { + arena: Default::default(), + head: Default::default(), + context: Default::default(), + gas_inspector, + } + } + + fn enter(&mut self, depth: usize, address: Address, kind: CallKind) { + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + + fn exit(&mut self) { + if let Some(parent_id) = self.arena.arena[self.head].parent { + let DebugNode { + depth, + address, + kind, + .. + } = self.arena.arena[parent_id]; + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + } +} + +impl Inspector for Debugger { + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + _is_static: bool, + ) -> Return { + let pc = interpreter.program_counter(); + let op = interpreter.contract.bytecode.bytecode()[pc]; + + let opcode_infos = spec_opcode_gas(data.env.cfg.spec_id); + let opcode_info = &opcode_infos[op as usize]; + + let push_size = if opcode_info.is_push() { + (op - opcode::PUSH1 + 1) as usize + } else { + 0 + }; + let push_bytes = match push_size { + 0 => None, + n => { + let start = pc + 1; + let end = start + n; + Some(interpreter.contract.bytecode.bytecode()[start..end].to_vec()) + } + }; + + let spent = interpreter.gas.limit() - self.gas_inspector.borrow().gas_remaining(); + let total_gas_used = spent - (interpreter.gas.refunded() as u64).min(spent / 5); + + self.arena.arena[self.head].steps.push(DebugStep { + pc, + stack: interpreter.stack().data().clone(), + memory: interpreter.memory.clone(), + instruction: Instruction(op), + push_bytes, + total_gas_used, + }); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + self.enter( + data.journaled_state.depth() as usize, + call.context.code_address, + call.context.scheme.into(), + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CallInputs, + gas: Gas, + status: Return, + retdata: Bytes, + _: bool, + ) -> (Return, Gas, Bytes) { + self.exit(); + + (status, gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + let nonce = data.journaled_state.account(call.caller).info.nonce; + self.enter( + data.journaled_state.depth() as usize, + get_create_address(call, nonce), + CallKind::Create, + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CreateInputs, + status: Return, + address: Option
, + gas: Gas, + retdata: Bytes, + ) -> (Return, Option
, Gas, Bytes) { + self.exit(); + + (status, address, gas, retdata) + } +} + +#[macro_export] +macro_rules! call_inspectors { + ($id:ident, [ $($inspector:expr),+ ], $call:block) => { + $({ + if let Some($id) = $inspector { + $call; + } + })+ + } +} + +struct InspectorData { + logs: Vec, + debug: Option, +} + +#[derive(Default)] +struct InspectorStack { + gas: Option>>, + logs: Option, + debugger: Option, +} + +impl InspectorStack { + fn collect_inspector_states(self) -> InspectorData { + InspectorData { + logs: self.logs.map(|logs| logs.logs).unwrap_or_default(), + debug: self.debugger.map(|debugger| debugger.arena), + } + } +} + +impl Inspector for InspectorStack { + fn initialize_interp( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.initialize_interp(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn log( + &mut self, + evm_data: &mut EVMData<'_, DB>, + address: &Address, + topics: &[H256], + data: &Bytes, + ) { + call_inspectors!(inspector, [&mut self.logs], { + inspector.log(evm_data, address, topics, data); + }); + } + + fn step_end( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + status: Return, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step_end(interpreter, data, is_static, status); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, gas, retdata) = inspector.call(data, call, is_static); + + if status != Return::Continue { + return (status, gas, retdata); + } + } + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CallInputs, + remaining_gas: Gas, + status: Return, + retdata: Bytes, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_gas, new_retdata) = inspector.call_end( + data, + call, + remaining_gas, + status, + retdata.clone(), + is_static, + ); + + if new_status != status || (new_status == Return::Revert && new_retdata != retdata) + { + return (new_status, new_gas, new_retdata); + } + } + ); + + (status, remaining_gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, addr, gas, retdata) = inspector.create(data, call); + + if status != Return::Continue { + return (status, addr, gas, retdata); + } + } + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CreateInputs, + status: Return, + address: Option
, + remaining_gas: Gas, + retdata: Bytes, + ) -> (Return, Option
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_address, new_gas, new_retdata) = inspector.create_end( + data, + call, + status, + address, + remaining_gas, + retdata.clone(), + ); + + if new_status != status { + return (new_status, new_address, new_gas, new_retdata); + } + } + ); + + (status, address, remaining_gas, retdata) + } + + fn selfdestruct(&mut self) { + call_inspectors!(inspector, [&mut self.logs, &mut self.debugger], { + Inspector::::selfdestruct(inspector); + }); + } +} + +pub struct RawCallResult { + pub exit_reason: Return, + pub reverted: bool, + pub result: Bytes, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub state_changeset: Option>, + pub env: Env, + pub out: TransactOut, +} + +#[derive(Clone, Debug)] +pub struct DeployResult { + pub exit_reason: Return, + pub reverted: bool, + pub address: Option
, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub env: Env, +} + +#[derive(Debug, Default)] +pub struct ExecutorBuilder { + debugger: bool, + gas_limit: Option, +} + +impl ExecutorBuilder { + pub fn set_debugger(mut self, enable: bool) -> Self { + self.debugger = enable; + self + } + + pub fn with_gas_limit(mut self, gas_limit: U256) -> Self { + self.gas_limit = Some(gas_limit); + self + } + + pub fn build(self) -> Executor { + Executor::new(self.debugger, self.gas_limit.unwrap_or(U256::MAX)) + } +} + +#[derive(Clone, Debug)] +pub struct Executor { + db: InMemoryDB, + debugger: bool, + gas_limit: U256, +} + +impl Executor { + fn new(debugger: bool, gas_limit: U256) -> Self { + Executor { + db: InMemoryDB::default(), + debugger, + gas_limit, + } + } + + pub fn db_mut(&mut self) -> &mut InMemoryDB { + &mut self.db + } + + pub fn deploy(&mut self, from: Address, code: Bytes, value: U256) -> DeployResult { + let env = self.build_test_env(from, TransactTo::Create(CreateScheme::Create), code, value); + let result = self.call_raw_with_env(env); + self.commit(&result); + + let RawCallResult { + exit_reason, + out, + gas_used, + gas_refunded, + logs, + debug, + env, + .. + } = result; + + let address = match (exit_reason, out) { + (return_ok!(), TransactOut::Create(_, Some(address))) => Some(address), + _ => None, + }; + + DeployResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + address, + gas_used, + gas_refunded, + logs, + debug, + env, + } + } + + pub fn call_raw( + &self, + from: Address, + to: Address, + calldata: Bytes, + value: U256, + ) -> RawCallResult { + let env = self.build_test_env(from, TransactTo::Call(to), calldata, value); + self.call_raw_with_env(env) + } + + fn call_raw_with_env(&self, mut env: Env) -> RawCallResult { + let mut inspector = self.inspector(); + let result = + evm_inner::<_, true>(&mut env, &mut self.db.clone(), &mut inspector).transact(); + let (exec_result, state_changeset) = result; + let ExecutionResult { + exit_reason, + gas_refunded, + gas_used, + out, + .. + } = exec_result; + + let result = match out { + TransactOut::Call(ref data) => data.to_owned(), + _ => Bytes::default(), + }; + let InspectorData { logs, debug } = inspector.collect_inspector_states(); + + RawCallResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + result, + gas_used, + gas_refunded, + logs: logs.to_vec(), + debug, + state_changeset: Some(state_changeset.into_iter().collect()), + env, + out, + } + } + + fn commit(&mut self, result: &RawCallResult) { + if let Some(state_changeset) = result.state_changeset.as_ref() { + self.db + .commit(state_changeset.clone().into_iter().collect()); + } + } + + fn inspector(&self) -> InspectorStack { + let mut stack = InspectorStack { + logs: Some(LogCollector::default()), + ..Default::default() + }; + if self.debugger { + let gas_inspector = Rc::new(RefCell::new(GasInspector::default())); + stack.gas = Some(gas_inspector.clone()); + stack.debugger = Some(Debugger::new(gas_inspector)); + } + stack + } + + fn build_test_env( + &self, + caller: Address, + transact_to: TransactTo, + data: Bytes, + value: U256, + ) -> Env { + Env { + block: BlockEnv { + gas_limit: self.gas_limit, + ..BlockEnv::default() + }, + tx: TxEnv { + caller, + transact_to, + data, + value, + gas_limit: self.gas_limit.as_u64(), + ..TxEnv::default() + }, + ..Env::default() + } + } +} diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 39e95c23..5b3361fa 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -23,7 +23,7 @@ mod util { Self: Sized, F: FnMut(B, V) -> B, { - self.into_iter().fold(Value::known(init), |acc, value| { + self.fold(Value::known(init), |acc, value| { acc.zip(value).map(|(acc, value)| f(acc, value)) }) } @@ -49,9 +49,7 @@ where self.transcript_initial_state .as_ref() .map(|transcript_initial_state| { - loader.assign_scalar(circuit::Value::known( - loader.scalar_chip().integer(*transcript_initial_state), - )) + loader.assign_scalar(circuit::Value::known(*transcript_initial_state)) }); Protocol { domain: self.domain.clone(), @@ -64,7 +62,7 @@ where quotient: self.quotient.clone(), transcript_initial_state, instance_committing_key: self.instance_committing_key.clone(), - linearization: self.linearization.clone(), + linearization: self.linearization, accumulator_indices: self.accumulator_indices.clone(), } } diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index d446c716..67c2b12d 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -45,15 +45,15 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.into_inner() } - pub fn ecc_chip(&self) -> Ref<'_, EccChip> { + pub fn ecc_chip(&self) -> Ref { self.ecc_chip.borrow() } - pub fn scalar_chip(&self) -> Ref<'_, EccChip::ScalarChip> { + pub fn scalar_chip(&self) -> Ref { Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) } - pub fn ctx(&self) -> Ref<'_, EccChip::Context> { + pub fn ctx(&self) -> Ref { self.ctx.borrow() } @@ -61,17 +61,15 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.borrow_mut() } - fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { - let assigned = self - .scalar_chip() + fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> EccChip::AssignedScalar { + self.scalar_chip() .assign_constant(&mut self.ctx_mut(), constant) - .unwrap(); - self.scalar_from_assigned(assigned) + .unwrap() } pub fn assign_scalar( self: &Rc, - scalar: circuit::Value, + scalar: circuit::Value, ) -> Scalar<'a, C, EccChip> { let assigned = self .scalar_chip() @@ -96,16 +94,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc Scalar { loader: self.clone(), index, - value, + value: value.into(), } } - fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { - self.ec_point_from_assigned( - self.ecc_chip() - .assign_constant(&mut self.ctx_mut(), constant) - .unwrap(), - ) + fn assign_const_ec_point(self: &Rc, constant: C) -> EccChip::AssignedEcPoint { + self.ecc_chip() + .assign_constant(&mut self.ctx_mut(), constant) + .unwrap() } pub fn assign_ec_point( @@ -135,7 +131,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc EcPoint { loader: self.clone(), index, - value, + value: value.into(), } } @@ -144,14 +140,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), (Value::Assigned(assigned), Value::Constant(constant)) | (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned.clone())], + &[(C::Scalar::one(), assigned)], *constant, ) .map(Value::Assigned) @@ -160,10 +156,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[ - (C::Scalar::one(), lhs.clone()), - (C::Scalar::one(), rhs.clone()), - ], + &[(C::Scalar::one(), lhs), (C::Scalar::one(), rhs)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -177,13 +170,13 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(-C::Scalar::one(), assigned.clone())], + &[(-C::Scalar::one(), assigned)], *constant, ) .map(Value::Assigned) @@ -192,7 +185,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned.clone())], + &[(C::Scalar::one(), assigned)], -*constant, ) .map(Value::Assigned) @@ -211,14 +204,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), (Value::Assigned(assigned), Value::Constant(constant)) | (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(*constant, assigned.clone())], + &[(*constant, assigned)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -227,7 +220,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_products_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), lhs.clone(), rhs.clone())], + &[(C::Scalar::one(), lhs, rhs)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -237,7 +230,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { - let output = match &scalar.value { + let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(constant.neg()), Value::Assigned(assigned) => { IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) @@ -249,7 +242,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { - let output = match &scalar.value { + let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), Value::Assigned(assigned) => Value::Assigned( IntegerInstructions::invert( @@ -295,11 +288,30 @@ pub enum Value { Assigned(L), } +impl Value { + fn maybe_const(&self) -> Option + where + T: Copy, + { + match self { + Value::Constant(constant) => Some(*constant), + _ => None, + } + } + + fn assigned(&self) -> &L { + match self { + Value::Assigned(assigned) => assigned, + _ => unreachable!(), + } + } +} + #[derive(Clone)] pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - value: Value, + value: RefCell>, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { @@ -307,11 +319,19 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> &self.loader } - pub fn assigned(&self) -> EccChip::AssignedScalar { - match &self.value { - Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), - Value::Assigned(assigned) => assigned.clone(), + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_scalar(constant)) } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() } } @@ -453,16 +473,31 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - value: Value, + value: RefCell>, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { - pub fn assigned(&self) -> EccChip::AssignedEcPoint { - match &self.value { - Value::Constant(constant) => self.loader.assign_const_ec_point(*constant).assigned(), - Value::Assigned(assigned) => assigned.clone(), + pub fn into_assigned(self) -> EccChip::AssignedEcPoint { + match self.value.into_inner() { + Value::Constant(constant) => self.loader.assign_const_ec_point(constant), + Value::Assigned(assigned) => assigned, } } + + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_ec_point(constant)) + } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() + } } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { @@ -558,35 +593,30 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader lhs: &EcPoint<'a, C, EccChip>, rhs: &EcPoint<'a, C, EccChip>, ) -> Result<(), crate::Error> { - match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => { - assert_eq!(lhs, rhs); - Ok(()) - } - (Value::Constant(constant), Value::Assigned(assigned)) - | (Value::Assigned(assigned), Value::Constant(constant)) => { - let constant = self.assign_const_ec_point(*constant).assigned(); - self.ecc_chip() - .assert_equal(&mut self.ctx_mut(), assigned, &constant) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => self - .ecc_chip() - .assert_equal(&mut self.ctx_mut(), lhs, rhs) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())), + if let (Value::Constant(lhs), Value::Constant(rhs)) = + (lhs.value().deref(), rhs.value().deref()) + { + assert_eq!(lhs, rhs); + Ok(()) + } else { + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), lhs.deref(), rhs.deref()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) } } fn multi_scalar_multiplication( pairs: &[( - >::LoadedScalar, - EcPoint<'a, C, EccChip>, + &>::LoadedScalar, + &EcPoint<'a, C, EccChip>, )], ) -> EcPoint<'a, C, EccChip> { let loader = &pairs[0].0.loader; let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = - pairs.iter().fold( + pairs.iter().cloned().fold( (C::identity(), Vec::new(), Vec::new(), Vec::new()), |( mut constant, @@ -594,22 +624,20 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader mut variable_base_non_scaled, mut variable_base_scaled, ), - (scalar, ec_point)| { - match (&ec_point.value, &scalar.value) { - (Value::Constant(ec_point), Value::Constant(scalar)) => { - constant = (*ec_point * scalar + constant).into() + (scalar, base)| { + match (scalar.value().deref(), base.value().deref()) { + (Value::Constant(scalar), Value::Constant(base)) => { + constant = (*base * scalar + constant).into() } - (Value::Constant(ec_point), Value::Assigned(scalar)) => { - fixed_base.push((scalar.clone(), *ec_point)) + (Value::Assigned(_), Value::Constant(base)) => { + fixed_base.push((scalar, *base)) } - (Value::Assigned(ec_point), Value::Constant(scalar)) + (Value::Constant(scalar), Value::Assigned(_)) if scalar.eq(&C::Scalar::one()) => { - variable_base_non_scaled.push(ec_point.clone()); - } - (Value::Assigned(ec_point), _) => { - variable_base_scaled.push((scalar.assigned(), ec_point.clone())) + variable_base_non_scaled.push(base); } + _ => variable_base_scaled.push((scalar, base)), }; ( constant, @@ -620,38 +648,47 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader }, ); - let fixed_base_msm = (!fixed_base.is_empty()).then(|| { - loader - .ecc_chip - .borrow_mut() - .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) - .unwrap() - }); - let variable_base_msm = (!variable_base_scaled.is_empty()).then(|| { - loader - .ecc_chip - .borrow_mut() - .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) - .unwrap() - }); + let fixed_base_msm = (!fixed_base.is_empty()) + .then(|| { + let fixed_base = fixed_base + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base)) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) + .unwrap() + }) + .map(RefCell::new); + let variable_base_msm = (!variable_base_scaled.is_empty()) + .then(|| { + let variable_base_scaled = variable_base_scaled + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base.assigned())) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) + .unwrap() + }) + .map(RefCell::new); let output = loader .ecc_chip() .sum_with_const( &mut loader.ctx_mut(), &variable_base_non_scaled .into_iter() - .chain(fixed_base_msm) - .chain(variable_base_msm) + .map(EcPoint::assigned) + .chain(fixed_base_msm.as_ref().map(RefCell::borrow)) + .chain(variable_base_msm.as_ref().map(RefCell::borrow)) .collect_vec(), constant, ) .unwrap(); - let normalized = loader - .ecc_chip() - .normalize(&mut loader.ctx_mut(), &output) - .unwrap(); - loader.ec_point_from_assigned(normalized) + loader.ec_point_from_assigned(output) } } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 97921d24..1d7b6258 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -3,7 +3,7 @@ use halo2_proofs::{ circuit::{Cell, Value}, plonk::Error, }; -use std::fmt::Debug; +use std::{fmt::Debug, ops::Deref}; pub trait Context: Debug { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; @@ -13,15 +13,13 @@ pub trait Context: Debug { pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { type Context: Context; - type Integer: Clone + Debug; + type AssignedCell: Clone + Debug; type AssignedInteger: Clone + Debug; - fn integer(&self, fe: F) -> Self::Integer; - fn assign_integer( &self, ctx: &mut Self::Context, - integer: Value, + integer: Value, ) -> Result; fn assign_constant( @@ -33,41 +31,45 @@ pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, Self::AssignedInteger)], + values: &[(F::Scalar, impl Deref)], constant: F::Scalar, ) -> Result; fn sum_products_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, Self::AssignedInteger, Self::AssignedInteger)], + values: &[( + F::Scalar, + impl Deref, + impl Deref, + )], constant: F::Scalar, ) -> Result; fn sub( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result; fn neg( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result; fn invert( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result; fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result<(), Error>; } @@ -77,57 +79,54 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { 'a, C::Scalar, Context = Self::Context, - Integer = Self::Scalar, + AssignedCell = Self::AssignedCell, AssignedInteger = Self::AssignedScalar, >; - type AssignedEcPoint: Clone + Debug; - type Scalar: Clone + Debug; + type AssignedCell: Clone + Debug; type AssignedScalar: Clone + Debug; + type AssignedEcPoint: Clone + Debug; fn scalar_chip(&self) -> &Self::ScalarChip; fn assign_constant( &self, ctx: &mut Self::Context, - point: C, + ec_point: C, ) -> Result; fn assign_point( &self, ctx: &mut Self::Context, - point: Value, + ec_point: Value, ) -> Result; fn sum_with_const( &self, ctx: &mut Self::Context, - values: &[Self::AssignedEcPoint], + values: &[impl Deref], constant: C, ) -> Result; fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, C)], + pairs: &[(impl Deref, C)], ) -> Result; fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], - ) -> Result; - - fn normalize( - &self, - ctx: &mut Self::Context, - point: &Self::AssignedEcPoint, + pairs: &[( + impl Deref, + impl Deref, + )], ) -> Result; fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedEcPoint, - b: &Self::AssignedEcPoint, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, ) -> Result<(), Error>; } @@ -152,7 +151,7 @@ mod halo2_wrong { AssignedPoint, BaseFieldEccChip, }; use rand::rngs::OsRng; - use std::iter; + use std::{iter, ops::Deref}; impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { @@ -166,17 +165,13 @@ mod halo2_wrong { impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { type Context = RegionCtx<'a, F>; - type Integer = F; + type AssignedCell = AssignedCell; type AssignedInteger = AssignedCell; - fn integer(&self, scalar: F) -> Self::Integer { - scalar - } - fn assign_integer( &self, ctx: &mut Self::Context, - integer: Value, + integer: Value, ) -> Result { self.assign_value(ctx, integer) } @@ -192,7 +187,7 @@ mod halo2_wrong { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F, Self::AssignedInteger)], + values: &[(F, impl Deref)], constant: F, ) -> Result { self.compose( @@ -208,7 +203,11 @@ mod halo2_wrong { fn sum_products_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F, Self::AssignedInteger, Self::AssignedInteger)], + values: &[( + F, + impl Deref, + impl Deref, + )], constant: F, ) -> Result { match values.len() { @@ -289,39 +288,39 @@ mod halo2_wrong { fn sub( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::sub(self, ctx, a, b) + MainGateInstructions::sub(self, ctx, lhs, rhs) } fn neg( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::neg_with_constant(self, ctx, a, F::zero()) + MainGateInstructions::neg_with_constant(self, ctx, value, F::zero()) } fn invert( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::invert_unsafe(self, ctx, a) + MainGateInstructions::invert_unsafe(self, ctx, value) } fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result<(), Error> { let mut eq = true; - a.value().zip(b.value()).map(|(lhs, rhs)| { + lhs.value().zip(rhs.value()).map(|(lhs, rhs)| { eq &= lhs == rhs; }); - MainGateInstructions::assert_equal(self, ctx, a, b) + MainGateInstructions::assert_equal(self, ctx, lhs, rhs) .and(eq.then_some(()).ok_or(Error::Synthesis)) } } @@ -331,9 +330,9 @@ mod halo2_wrong { { type Context = RegionCtx<'a, C::Scalar>; type ScalarChip = MainGate; - type AssignedEcPoint = AssignedPoint; - type Scalar = C::Scalar; + type AssignedCell = AssignedCell; type AssignedScalar = AssignedCell; + type AssignedEcPoint = AssignedPoint; fn scalar_chip(&self) -> &Self::ScalarChip { self.main_gate() @@ -342,65 +341,79 @@ mod halo2_wrong { fn assign_constant( &self, ctx: &mut Self::Context, - point: C, + ec_point: C, ) -> Result { - self.assign_constant(ctx, point) + self.assign_constant(ctx, ec_point) } fn assign_point( &self, ctx: &mut Self::Context, - point: Value, + ec_point: Value, ) -> Result { - self.assign_point(ctx, point) + self.assign_point(ctx, ec_point) } fn sum_with_const( &self, ctx: &mut Self::Context, - values: &[Self::AssignedEcPoint], + values: &[impl Deref], constant: C, ) -> Result { if values.is_empty() { return self.assign_constant(ctx, constant); } - iter::empty() - .chain( - (!bool::from(constant.is_identity())) - .then(|| self.assign_constant(ctx, constant)), - ) - .chain(values.iter().cloned().map(Ok)) + let constant = (!bool::from(constant.is_identity())) + .then(|| self.assign_constant(ctx, constant)) + .transpose()?; + let output = iter::empty() + .chain(constant) + .chain(values.iter().map(|value| value.deref().clone())) + .map(Ok) .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) - .unwrap() + .unwrap()?; + self.normalize(ctx, &output) } fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, C)], + pairs: &[(impl Deref, C)], ) -> Result { + assert!(!pairs.is_empty()); + // FIXME: Implement fixed base MSM in halo2_wrong let pairs = pairs .iter() + .filter(|(_, base)| !bool::from(base.is_identity())) .map(|(scalar, base)| { - Ok::<_, Error>((scalar.clone(), self.assign_constant(ctx, *base)?)) + Ok::<_, Error>((scalar.deref().clone(), self.assign_constant(ctx, *base)?)) }) .collect::, _>>()?; + let pairs = pairs + .iter() + .map(|(scalar, base)| (scalar, base)) + .collect_vec(); self.variable_base_msm(ctx, &pairs) } fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], + pairs: &[( + impl Deref, + impl Deref, + )], ) -> Result { + assert!(!pairs.is_empty()); + const WINDOW_SIZE: usize = 3; let pairs = pairs .iter() - .map(|(scalar, base)| (base.clone(), scalar.clone())) + .map(|(scalar, base)| (base.deref().clone(), scalar.deref().clone())) .collect_vec(); - match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { + let output = match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { Err(_) => { if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { let aux_generator = Value::known(C::Curve::random(OsRng).into()); @@ -410,30 +423,23 @@ mod halo2_wrong { self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) } result => result, - } - } - - fn normalize( - &self, - ctx: &mut Self::Context, - point: &Self::AssignedEcPoint, - ) -> Result { - self.normalize(ctx, point) + }?; + self.normalize(ctx, &output) } fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedEcPoint, - b: &Self::AssignedEcPoint, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, ) -> Result<(), Error> { let mut eq = true; - [(a.x(), b.x()), (a.y(), b.y())].map(|(lhs, rhs)| { + [(lhs.x(), rhs.x()), (lhs.y(), rhs.y())].map(|(lhs, rhs)| { lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { eq &= lhs.value() == rhs.value(); }); }); - self.assert_equal(ctx, a, b) + self.assert_equal(ctx, lhs, rhs) .and(eq.then_some(()).ok_or(Error::Synthesis)) } } diff --git a/src/loader/native.rs b/src/loader/native.rs index 1451ff76..6fce383a 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -54,10 +54,11 @@ impl EcPointLoader for NativeLoader { } fn multi_scalar_multiplication( - pairs: &[(>::LoadedScalar, C)], + pairs: &[(&>::LoadedScalar, &C)], ) -> C { pairs .iter() + .cloned() .map(|(scalar, base)| *base * scalar) .reduce(|acc, value| acc + value) .unwrap() diff --git a/src/pcs.rs b/src/pcs.rs index 65804895..23fdda2a 100644 --- a/src/pcs.rs +++ b/src/pcs.rs @@ -123,7 +123,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(repr: Vec) -> Result; + fn from_repr(repr: &[&L::LoadedScalar]) -> Result; } impl AccumulatorEncoding for () @@ -132,7 +132,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(_: Vec) -> Result { + fn from_repr(_: &[&L::LoadedScalar]) -> Result { unimplemented!() } } diff --git a/src/pcs/kzg.rs b/src/pcs/kzg.rs index 9f10bd44..056589a8 100644 --- a/src/pcs/kzg.rs +++ b/src/pcs/kzg.rs @@ -15,6 +15,9 @@ pub use accumulator::{KzgAccumulator, LimbsEncoding}; pub use decider::KzgDecidingKey; pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; +#[cfg(feature = "loader_halo2")] +pub use accumulator::LimbsEncodingInstructions; + #[derive(Clone, Debug)] pub struct Kzg(PhantomData<(M, MOS)>); diff --git a/src/pcs/kzg/accumulation.rs b/src/pcs/kzg/accumulation.rs index cd13fd00..4273ce9a 100644 --- a/src/pcs/kzg/accumulation.rs +++ b/src/pcs/kzg/accumulation.rs @@ -44,16 +44,16 @@ where ) -> Result { let (lhs, rhs) = instances .iter() - .cloned() - .map(|accumulator| (accumulator.lhs, accumulator.rhs)) - .chain(proof.blind.clone()) + .map(|accumulator| (&accumulator.lhs, &accumulator.rhs)) + .chain(proof.blind.as_ref().map(|(lhs, rhs)| (lhs, rhs))) .unzip::<_, _, Vec<_>, Vec<_>>(); let powers_of_r = proof.r.powers(lhs.len()); - let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + let [lhs, rhs] = [lhs, rhs].map(|bases| { + bases + .into_iter() .zip(powers_of_r.iter()) - .map(|(msm, r)| Msm::::base(msm) * r) + .map(|(base, r)| Msm::::base(base) * r) .sum::>() .evaluate(None) }); @@ -184,7 +184,7 @@ where let powers_of_r = r.powers(lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_r.iter()) .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) .sum::>() diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs index 17c7bf83..50ce6ba8 100644 --- a/src/pcs/kzg/accumulator.rs +++ b/src/pcs/kzg/accumulator.rs @@ -53,13 +53,22 @@ mod native { Accumulator = KzgAccumulator, >, { - fn from_repr(limbs: Vec) -> Result { + fn from_repr(limbs: &[&C::Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs .chunks(LIMBS) .into_iter() - .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + .map(|limbs| { + fe_from_limbs::<_, _, LIMBS, BITS>( + limbs + .iter() + .map(|limb| **limb) + .collect_vec() + .try_into() + .unwrap(), + ) + }) .collect_vec() .try_into() .unwrap(); @@ -100,7 +109,7 @@ mod evm { Accumulator = KzgAccumulator>, >, { - fn from_repr(limbs: Vec) -> Result { + fn from_repr(limbs: &[&Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); @@ -122,10 +131,13 @@ mod evm { } } +#[cfg(feature = "loader_halo2")] +pub use halo2::LimbsEncodingInstructions; + #[cfg(feature = "loader_halo2")] mod halo2 { use crate::{ - loader::halo2::{Context, EccInstructions, Halo2Loader, Scalar, Valuetools}, + loader::halo2::{EccInstructions, Halo2Loader, Scalar, Valuetools}, pcs::{ kzg::{KzgAccumulator, LimbsEncoding}, AccumulatorEncoding, PolynomialCommitmentScheme, @@ -136,19 +148,18 @@ mod halo2 { }, Error, }; - use halo2_proofs::circuit::Value; - use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; - use std::{iter, rc::Rc}; + use halo2_proofs::{circuit::Value, plonk}; + use std::{iter, ops::Deref, rc::Rc}; - fn ec_point_from_assigned_limbs( - limbs: &[AssignedValue], + fn ec_point_from_limbs( + limbs: &[Value<&C::Scalar>], ) -> Value { assert_eq!(limbs.len(), 2 * LIMBS); let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { limbs .iter() - .map(|assigned| assigned.value()) + .cloned() .fold_zipped(Vec::new(), |mut acc, limb| { acc.push(*limb); acc @@ -159,6 +170,22 @@ mod halo2 { x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) } + pub trait LimbsEncodingInstructions<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize>: + EccInstructions<'a, C> + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result; + + fn assign_ec_point_to_limbs( + &self, + ctx: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error>; + } + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> AccumulatorEncoding>, PCS> for LimbsEncoding where @@ -168,41 +195,72 @@ mod halo2 { Rc>, Accumulator = KzgAccumulator>>, >, - EccChip: EccInstructions< - 'a, - C, - AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, - AssignedScalar = AssignedValue, - >, + EccChip: LimbsEncodingInstructions<'a, C, LIMBS, BITS>, { - fn from_repr(limbs: Vec>) -> Result { + fn from_repr(limbs: &[&Scalar<'a, C, EccChip>]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); - let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); - let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( - |assigned_limbs| { - let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>(assigned_limbs); - loader.assign_ec_point(ec_point) - }, - ); - - for (src, dst) in assigned_limbs.iter().zip( - iter::empty() - .chain(lhs.assigned().x().limbs()) - .chain(lhs.assigned().y().limbs()) - .chain(rhs.assigned().x().limbs()) - .chain(rhs.assigned().y().limbs()), - ) { - loader - .ctx_mut() - .constrain_equal(src.cell(), dst.as_ref().cell()) + let [lhs, rhs] = [&limbs[..2 * LIMBS], &limbs[2 * LIMBS..]].map(|limbs| { + let assigned = loader + .ecc_chip() + .assign_ec_point_from_limbs( + &mut loader.ctx_mut(), + &limbs.iter().map(|limb| limb.assigned()).collect_vec(), + ) .unwrap(); + loader.ec_point_from_assigned(assigned) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } + } + + mod halo2_wrong { + use super::*; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> + LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result { + assert_eq!(limbs.len(), 2 * LIMBS); + + let ec_point = self.assign_point( + ctx, + ec_point_from_limbs::<_, LIMBS, BITS>( + &limbs.iter().map(|limb| limb.value()).collect_vec(), + ), + )?; + + for (src, dst) in limbs.iter().zip_eq( + iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()), + ) { + ctx.constrain_equal(src.cell(), dst.as_ref().cell())?; + } + + Ok(ec_point) } - let accumulator = KzgAccumulator::new(lhs, rhs); - Ok(accumulator) + fn assign_ec_point_to_limbs( + &self, + _: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error> { + Ok(iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()) + .map(|limb| limb.as_ref()) + .cloned() + .collect()) + } } } } diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs index b6957883..de3e2a06 100644 --- a/src/pcs/kzg/decider.rs +++ b/src/pcs/kzg/decider.rs @@ -144,7 +144,7 @@ mod evm { let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_challenge.iter()) .map(|(msm, power_of_challenge)| { Msm::>::base(msm) * power_of_challenge diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs index 7f70ca61..f542f750 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -47,8 +47,8 @@ where queries: &[Query], proof: &Bdfg21Proof, ) -> Result { + let sets = query_sets(queries); let f = { - let sets = query_sets(queries); let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); let powers_of_mu = proof @@ -62,10 +62,10 @@ where msms.zip(proof.gamma.powers(sets.len()).into_iter()) .map(|(msm, power_of_gamma)| msm * &power_of_gamma) .sum::>() - - Msm::base(proof.w.clone()) * &coeffs[0].z_s + - Msm::base(&proof.w) * &coeffs[0].z_s }; - let rhs = Msm::base(proof.w_prime.clone()); + let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; Ok(KzgAccumulator::new( @@ -143,7 +143,7 @@ fn query_sets(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec>( - sets: &[QuerySet], +fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( + sets: &[QuerySet<'a, F, T>], z: &T, z_prime: &T, ) -> Vec> { @@ -211,17 +211,17 @@ fn query_set_coeffs>( } #[derive(Clone, Debug)] -struct QuerySet { +struct QuerySet<'a, F, T> { shifts: Vec, polys: Vec, - evals: Vec>, + evals: Vec>, } -impl> QuerySet { +impl<'a, F: FieldExt, T: LoadedScalar> QuerySet<'a, F, T> { fn msm>( &self, coeff: &QuerySetCoeff, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_mu: &[T], ) -> Msm { self.polys @@ -241,7 +241,7 @@ impl> QuerySet { &coeff .eval_coeffs .iter() - .zip(evals.iter()) + .zip(evals.iter().cloned()) .map(|(coeff, eval)| (coeff.evaluated(), eval)) .collect_vec(), ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated(); @@ -288,18 +288,15 @@ where }) .collect_vec(); - let z = &powers_of_z[1].clone(); + let z = &powers_of_z[1]; let z_pow_k_minus_one = { let k_minus_one = shifts.len() - 1; powers_of_z .iter() .enumerate() .skip(1) - .filter_map(|(i, power_of_z)| { - (k_minus_one & (1 << i) == 1).then(|| power_of_z.clone()) - }) - .reduce(|acc, value| acc * value) - .unwrap_or_else(|| loader.load_one()) + .filter_map(|(i, power_of_z)| (k_minus_one & (1 << i) == 1).then(|| power_of_z)) + .fold(loader.load_one(), |acc, value| acc * value) }; let barycentric_weights = shifts @@ -354,7 +351,7 @@ where .map(Fraction::evaluated) .collect_vec(), ); - self.r_eval_coeff = Some(match self.commitment_coeff.clone() { + self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), }); diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/src/pcs/kzg/multiopen/gwc19.rs index 121fce8a..6e3f579f 100644 --- a/src/pcs/kzg/multiopen/gwc19.rs +++ b/src/pcs/kzg/multiopen/gwc19.rs @@ -55,15 +55,13 @@ where .map(|(msm, power_of_u)| msm * power_of_u) .sum::>() }; - let z_omegas = sets - .iter() - .map(|set| z.clone() * &z.loader().load_const(&set.shift)); + let z_omegas = sets.iter().map(|set| z.loader().load_const(&set.shift) * z); let rhs = proof .ws .iter() .zip(powers_of_u.iter()) - .map(|(w, power_of_u)| Msm::base(w.clone()) * power_of_u) + .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); let lhs = f + rhs .iter() @@ -105,25 +103,25 @@ where } } -struct QuerySet { +struct QuerySet<'a, F, T> { shift: F, polys: Vec, - evals: Vec, + evals: Vec<&'a T>, } -impl QuerySet +impl<'a, F, T> QuerySet<'a, F, T> where F: PrimeField, T: Clone, { fn msm>( &self, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_v: &[L::LoadedScalar], ) -> Msm { self.polys .iter() - .zip(self.evals.iter()) + .zip(self.evals.iter().cloned()) .map(|(poly, eval)| { let commitment = commitments[*poly].clone(); commitment - Msm::constant(eval.clone()) @@ -142,12 +140,12 @@ where queries.iter().fold(Vec::new(), |mut sets, query| { if let Some(pos) = sets.iter().position(|set| set.shift == query.shift) { sets[pos].polys.push(query.poly); - sets[pos].evals.push(query.eval.clone()); + sets[pos].evals.push(&query.eval); } else { sets.push(QuerySet { shift: query.shift, polys: vec![query.poly], - evals: vec![query.eval.clone()], + evals: vec![&query.eval], }); } sets diff --git a/src/system/halo2.rs b/src/system/halo2.rs index 3d7cf140..a49fa6ef 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -601,11 +601,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .zip(permutation_fixeds.chunks(self.permutation_chunk_size)) .enumerate() .map( - |(i, ((((z, z_w, _), (_, z_next_w, _)), polys), permutation_fixeds))| { + |( + i, + ((((z, z_omega, _), (_, z_next_omega, _)), polys), permutation_fixeds), + )| { let left = if self.zk || zs.len() == 1 { - z_w.clone() + z_omega.clone() } else { - z_w + l_last * (z_next_w - z_w) + z_omega + l_last * (z_next_omega - z_omega) } * polys .iter() .zip(permutation_fixeds.iter()) @@ -675,7 +678,10 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .iter() .zip(polys.iter()) .flat_map( - |(lookup, (z, z_w, permuted_input, permuted_input_w_inv, permuted_table))| { + |( + lookup, + (z, z_omega, permuted_input, permuted_input_omega_inv, permuted_table), + )| { let input = compress(lookup.input_expressions()); let table = compress(lookup.table_expressions()); iter::empty() @@ -683,20 +689,20 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(self.zk.then(|| l_last * (z * z - z))) .chain(Some(if self.zk { l_active - * (z_w * (permuted_input + beta) * (permuted_table + gamma) + * (z_omega * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma)) } else { - z_w * (permuted_input + beta) * (permuted_table + gamma) + z_omega * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma) })) .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) .chain(Some(if self.zk { l_active * (permuted_input - permuted_table) - * (permuted_input - permuted_input_w_inv) + * (permuted_input - permuted_input_omega_inv) } else { (permuted_input - permuted_table) - * (permuted_input - permuted_input_w_inv) + * (permuted_input - permuted_input_omega_inv) })) }, ) diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs index 3b7e1694..2b267758 100644 --- a/src/system/halo2/test/kzg.rs +++ b/src/system/halo2/test/kzg.rs @@ -27,8 +27,8 @@ pub fn main_gate_with_range_with_mock_kzg_accumulator( let srs = read_or_create_srs(TESTDATA_DIR, 1, setup::); let [g1, s_g1] = [srs.get_g()[0], srs.get_g()[1]].map(|point| point.coordinates().unwrap()); MainGateWithRange::new( - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() + [s_g1.x(), s_g1.y(), g1.x(), g1.y()] + .into_iter() .cloned() .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) .collect(), diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index 18767c98..11a6046f 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -7,7 +7,7 @@ use crate::{ pcs::{ kzg::{ Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, - KzgSuccinctVerifyingKey, LimbsEncoding, + KzgSuccinctVerifyingKey, LimbsEncoding, LimbsEncodingInstructions, }, AccumulationScheme, AccumulationSchemeProver, }, @@ -31,7 +31,7 @@ use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{floor_planner::V1, Layouter, Value}, plonk, - plonk::Circuit, + plonk::{Circuit, Error}, poly::{ commitment::ParamsProver, kzg::{ @@ -48,7 +48,7 @@ use halo2_wrong_ecc::{ }; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{iter, rc::Rc}; +use std::rc::Rc; const T: usize = 5; const RATE: usize = 4; @@ -281,14 +281,14 @@ impl Circuit for Accumulation { range_chip.load_table(&mut layouter)?; - let (lhs, rhs) = layouter.assign_region( + let accumulator_limbs = layouter.assign_region( || "", |region| { let ctx = RegionCtx::new(region, 0); let ecc_chip = config.ecc_chip(); let loader = Halo2Loader::new(ecc_chip, ctx); - let KzgAccumulator { lhs, rhs } = accumulate( + let accumulator = accumulate( &self.svk, &loader, &self.snarks, @@ -296,21 +296,26 @@ impl Circuit for Accumulation { self.as_proof(), ); + let accumulator_limbs = [accumulator.lhs, accumulator.rhs] + .iter() + .map(|ec_point| { + loader + .ecc_chip() + .assign_ec_point_to_limbs(&mut loader.ctx_mut(), ec_point.assigned()) + }) + .collect::, Error>>()? + .into_iter() + .flatten(); + loader.print_row_metering(); println!("Total row cost: {}", loader.ctx().offset()); - Ok((lhs.assigned(), rhs.assigned())) + Ok(accumulator_limbs) }, )?; - for (limb, row) in iter::empty() - .chain(lhs.x().limbs()) - .chain(lhs.y().limbs()) - .chain(rhs.x().limbs()) - .chain(rhs.y().limbs()) - .zip(0..) - { - main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + for (row, limb) in accumulator_limbs.enumerate() { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; } Ok(()) diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index bb7e8e23..ab7d548c 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -1,60 +1,58 @@ use crate::{ loader::{ - halo2::{EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + halo2::{EcPoint, EccInstructions, Halo2Loader, Scalar}, native::{self, NativeLoader}, Loader, ScalarLoader, }, util::{ - arithmetic::{fe_to_fe, CurveAffine, FieldExt, PrimeField}, + arithmetic::{fe_to_fe, CurveAffine, PrimeField}, hash::Poseidon, transcript::{Transcript, TranscriptRead, TranscriptWrite}, Itertools, }, Error, }; -use halo2_proofs::{ - circuit::{AssignedCell, Value}, - transcript::EncodedChallenge, -}; +use halo2_proofs::{circuit::Value, transcript::EncodedChallenge}; use std::{ io::{self, Read, Write}, - marker::PhantomData, rc::Rc, }; -pub trait EncodeNative<'a, C: CurveAffine, N: FieldExt>: EccInstructions<'a, C> { - fn encode_native( +/// Encoding that encodes elliptic curve point into native field elements. +pub trait NativeEncoding<'a, C>: EccInstructions<'a, C> +where + C: CurveAffine, +{ + fn encode( &self, ctx: &mut Self::Context, ec_point: &Self::AssignedEcPoint, - ) -> Result>, Error>; + ) -> Result, Error>; } pub struct PoseidonTranscript< - C: CurveAffine, - L: Loader, + C, + L, S, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize, -> { +> where + C: CurveAffine, + L: Loader, +{ loader: L, stream: S, buf: Poseidon>::LoadedScalar, T, RATE>, - _marker: PhantomData, } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > PoseidonTranscript>, Value, T, RATE, R_F, R_P> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { pub fn new(loader: &Rc>, stream: Value) -> Self { let buf = Poseidon::new(loader, R_F, R_P); @@ -62,22 +60,17 @@ impl< loader: loader.clone(), stream, buf, - _marker: PhantomData, } } } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript>> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + Transcript>> for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { fn loader(&self) -> &Rc> { &self.loader @@ -96,30 +89,31 @@ impl< let encoded = self .loader .ecc_chip() - .encode_native(&mut self.loader.ctx_mut(), &ec_point.assigned()) + .encode(&mut self.loader.ctx_mut(), &ec_point.assigned()) .map(|encoded| { encoded .into_iter() .map(|encoded| self.loader.scalar_from_assigned(encoded)) .collect_vec() }) - .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; + .map_err(|_| { + Error::Transcript( + io::ErrorKind::Other, + "Failed to encode elliptic curve point into native field elements".to_string(), + ) + })?; self.buf.update(&encoded); Ok(()) } } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead>> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + TranscriptRead>> for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { fn read_scalar(&mut self) -> Result, Error> { let scalar = self.stream.as_mut().and_then(|stream| { @@ -128,7 +122,7 @@ impl< return Value::unknown(); } Option::::from(C::Scalar::from_repr(data)) - .map(|scalar| Value::known(self.loader.scalar_chip().integer(scalar))) + .map(Value::known) .unwrap_or_else(Value::unknown) }); let scalar = self.loader.assign_scalar(scalar); @@ -152,8 +146,6 @@ impl< } } -// - impl PoseidonTranscript { @@ -162,7 +154,6 @@ impl TranscriptRead - for PoseidonTranscript +impl + TranscriptRead for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn read_scalar(&mut self) -> Result { let mut data = ::Repr::default(); @@ -243,14 +230,11 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > PoseidonTranscript +impl + PoseidonTranscript +where + C: CurveAffine, + W: Write, { pub fn stream_mut(&mut self) -> &mut W { &mut self.stream @@ -261,14 +245,11 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWrite for PoseidonTranscript +impl TranscriptWrite + for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { self.common_scalar(&scalar)?; @@ -332,15 +313,12 @@ impl halo2_proofs::transcript::TranscriptRead> +impl + halo2_proofs::transcript::TranscriptRead> for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn read_point(&mut self) -> io::Result { match TranscriptRead::read_ec_point(self) { @@ -359,30 +337,24 @@ impl< } } -impl< - C: CurveAffine, - R: Read, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptReadBuffer> +impl + halo2_proofs::transcript::TranscriptReadBuffer> for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn init(reader: R) -> Self { Self::new(reader) } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWrite> +impl + halo2_proofs::transcript::TranscriptWrite> for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn write_point(&mut self, ec_point: C) -> io::Result<()> { halo2_proofs::transcript::Transcript::>::common_point( @@ -399,15 +371,12 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWriterBuffer> +impl + halo2_proofs::transcript::TranscriptWriterBuffer> for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn init(writer: W) -> Self { Self::new(writer) @@ -419,15 +388,15 @@ impl< } mod halo2_wrong { - use crate::system::halo2::transcript::halo2::EncodeNative; + use crate::system::halo2::transcript::halo2::NativeEncoding; use halo2_curves::CurveAffine; use halo2_proofs::circuit::AssignedCell; use halo2_wrong_ecc::BaseFieldEccChip; - impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EncodeNative<'a, C, C::Scalar> + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> NativeEncoding<'a, C> for BaseFieldEccChip { - fn encode_native( + fn encode( &self, _: &mut Self::Context, ec_point: &Self::AssignedEcPoint, diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index b9a5c7c6..02d6da32 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -186,14 +186,14 @@ impl Fraction { self.eval = Some( self.numer - .as_ref() - .map(|numer| numer.clone() * &self.denom) + .take() + .map(|numer| numer * &self.denom) .unwrap_or_else(|| self.denom.clone()), ); } pub fn evaluated(&self) -> &T { - assert!(self.inv); + assert!(self.eval.is_some()); self.eval.as_ref().unwrap() } @@ -241,19 +241,12 @@ pub fn fe_to_limbs [F2; LIMBS] { let big = BigUint::from_bytes_le(fe.to_repr().as_ref()); - let mask = (BigUint::one() << BITS) - 1usize; + let mask = &((BigUint::one() << BITS) - 1usize); (0usize..) .step_by(BITS) .take(LIMBS) - .map(move |shift| fe_from_big((&big >> shift) & &mask)) + .map(|shift| fe_from_big((&big >> shift) & mask)) .collect_vec() .try_into() .unwrap() } - -pub fn powers(scalar: F) -> impl Iterator -where - for<'a> F: Mul<&'a F, Output = F> + One + Clone, -{ - iter::successors(Some(F::one()), move |power| Some(scalar.clone() * power)) -} diff --git a/src/util/msm.rs b/src/util/msm.rs index a15eab9d..f35ea197 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -9,13 +9,13 @@ use std::{ }; #[derive(Clone, Debug)] -pub struct Msm> { +pub struct Msm<'a, C: CurveAffine, L: Loader> { constant: Option, scalars: Vec, - bases: Vec, + bases: Vec<&'a L::LoadedEcPoint>, } -impl Default for Msm +impl<'a, C, L> Default for Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -29,7 +29,7 @@ where } } -impl Msm +impl<'a, C, L> Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -41,7 +41,7 @@ where } } - pub fn base(base: L::LoadedEcPoint) -> Self { + pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self { let one = base.loader().load_one(); Msm { scalars: vec![one], @@ -72,8 +72,12 @@ where .ec_point_load_const(&gen) }); let pairs = iter::empty() - .chain(self.constant.map(|constant| (constant, gen.unwrap()))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())) + .chain( + self.constant + .as_ref() + .map(|constant| (constant, gen.as_ref().unwrap())), + ) + .chain(self.scalars.iter().zip(self.bases.into_iter())) .collect_vec(); L::multi_scalar_multiplication(&pairs) } @@ -87,16 +91,16 @@ where } } - pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { + pub fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) { if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { - self.scalars[pos] += scalar; + self.scalars[pos] += &scalar; } else { self.scalars.push(scalar); self.bases.push(base); } } - pub fn extend(&mut self, mut other: Self) { + pub fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) { match (self.constant.as_mut(), other.constant.as_ref()) { (Some(lhs), Some(rhs)) => *lhs += rhs, (None, Some(_)) => self.constant = other.constant.take(), @@ -108,58 +112,62 @@ where } } -impl Add> for Msm +impl<'a, 'b, C, L> Add> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; - fn add(mut self, rhs: Msm) -> Self::Output { + fn add(mut self, rhs: Msm<'b, C, L>) -> Self::Output { self.extend(rhs); self } } -impl AddAssign> for Msm +impl<'a, 'b, C, L> AddAssign> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - fn add_assign(&mut self, rhs: Msm) { + fn add_assign(&mut self, rhs: Msm<'b, C, L>) { self.extend(rhs); } } -impl Sub> for Msm +impl<'a, 'b, C, L> Sub> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; - fn sub(mut self, rhs: Msm) -> Self::Output { + fn sub(mut self, rhs: Msm<'b, C, L>) -> Self::Output { self.extend(-rhs); self } } -impl SubAssign> for Msm +impl<'a, 'b, C, L> SubAssign> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - fn sub_assign(&mut self, rhs: Msm) { + fn sub_assign(&mut self, rhs: Msm<'b, C, L>) { self.extend(-rhs); } } -impl Mul<&L::LoadedScalar> for Msm +impl<'a, C, L> Mul<&L::LoadedScalar> for Msm<'a, C, L> where C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { self.scale(rhs); @@ -167,7 +175,7 @@ where } } -impl MulAssign<&L::LoadedScalar> for Msm +impl<'a, C, L> MulAssign<&L::LoadedScalar> for Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -177,13 +185,13 @@ where } } -impl Neg for Msm +impl<'a, C, L> Neg for Msm<'a, C, L> where C: CurveAffine, L: Loader, { - type Output = Msm; - fn neg(mut self) -> Msm { + type Output = Msm<'a, C, L>; + fn neg(mut self) -> Msm<'a, C, L> { self.constant = self.constant.map(|constant| -constant); for scalar in self.scalars.iter_mut() { *scalar = -scalar.clone(); @@ -192,7 +200,7 @@ where } } -impl Sum for Msm +impl<'a, C, L> Sum for Msm<'a, C, L> where C: CurveAffine, L: Loader, diff --git a/src/util/protocol.rs b/src/util/protocol.rs index e9ecd9f0..f7747060 100644 --- a/src/util/protocol.rs +++ b/src/util/protocol.rs @@ -41,7 +41,7 @@ where quotient: self.quotient.clone(), transcript_initial_state, instance_committing_key: self.instance_committing_key.clone(), - linearization: self.linearization.clone(), + linearization: self.linearization, accumulator_indices: self.accumulator_indices.clone(), } } @@ -82,11 +82,11 @@ where let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); - let zn_minus_one = zn.clone() - one; + let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); let n_inv = loader.load_const(&domain.n_inv); - let numer = zn_minus_one.clone() * n_inv; + let numer = zn_minus_one.clone() * &n_inv; let omegas = langranges .iter() .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) @@ -378,7 +378,7 @@ fn merge_left_right(a: Option>, b: Option>) -> O } } -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub enum LinearizationStrategy { /// Older linearization strategy of GWC19, which has linearization /// polynomial that doesn't evaluate to 0, and requires prover to send extra diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index c3276af4..df208949 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -137,8 +137,8 @@ where instances .iter() .zip(bases.iter()) - .map(|(scalar, base)| Msm::::base(base.clone()) * scalar) - .chain(constant.clone().map(|constant| Msm::base(constant))) + .map(|(scalar, base)| Msm::::base(base) * scalar) + .chain(constant.as_ref().map(Msm::base)) .sum::>() .evaluate(None) }) @@ -190,12 +190,13 @@ where .accumulator_indices .iter() .map(|accumulator_indices| { - accumulator_indices - .iter() - .map(|&(i, j)| instances[i][j].clone()) - .collect() + AE::from_repr( + &accumulator_indices + .iter() + .map(|&(i, j)| &instances[i][j]) + .collect_vec(), + ) }) - .map(AE::from_repr) .collect::, _>>()?; Ok(Self { @@ -241,25 +242,20 @@ where .collect() } - fn commitments( - &self, - protocol: &Protocol, + fn commitments<'a>( + &'a self, + protocol: &'a Protocol, common_poly_eval: &CommonPolynomialEvaluation, evaluations: &mut HashMap, ) -> Result>, Error> { let loader = common_poly_eval.zn().loader(); let mut commitments = iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| Msm::base(value.clone())), - ) + .chain(protocol.preprocessed.iter().map(Msm::base)) .chain( self.committed_instances - .clone() + .as_ref() .map(|committed_instances| { - committed_instances.into_iter().map(Msm::base).collect_vec() + committed_instances.iter().map(Msm::base).collect_vec() }) .unwrap_or_else(|| { iter::repeat_with(Default::default) @@ -267,7 +263,7 @@ where .collect_vec() }), ) - .chain(self.witnesses.iter().cloned().map(Msm::base)) + .chain(self.witnesses.iter().map(Msm::base)) .collect_vec(); let numerator = protocol.quotient.numerator.evaluate( @@ -314,7 +310,7 @@ where .pow_const(protocol.quotient.chunk_degree as u64) .powers(self.quotients.len()) .into_iter() - .zip(self.quotients.iter().cloned().map(Msm::base)) + .zip(self.quotients.iter().map(Msm::base)) .map(|(coeff, chunk)| chunk * &coeff) .sum::>(); match protocol.linearization {