diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 10a7b269..8f08883d 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -17,7 +17,7 @@ use halo2_proofs::{ use itertools::Itertools; use plonk_verifier::{ loader::{ - evm::{encode_calldata, EvmLoader, ExecutorBuilder}, + evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, @@ -570,7 +570,7 @@ fn gen_aggregation_evm_verifier( let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.deployment_code() + evm::compile_yul(&loader.yul_code()) } fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index 4ff5c682..b45bb59d 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -20,7 +20,7 @@ use halo2_proofs::{ }; use itertools::Itertools; use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader, ExecutorBuilder}, + loader::evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, @@ -222,7 +222,7 @@ fn gen_evm_verifier( let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.deployment_code() + evm::compile_yul(&loader.yul_code()) } fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { diff --git a/src/loader/evm.rs b/src/loader/evm.rs index fa80b97e..263da0e2 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -7,7 +7,8 @@ mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; pub use util::{ - encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, MemoryChunk, + compile_yul, encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, + MemoryChunk, }; pub use ethereum_types::U256; diff --git a/src/loader/evm/code.rs b/src/loader/evm/code.rs index 80dd5c71..840d1e67 100644 --- a/src/loader/evm/code.rs +++ b/src/loader/evm/code.rs @@ -1,7 +1,3 @@ -use crate::util::Itertools; -use ethereum_types::U256; -use std::{collections::HashMap, iter}; - pub enum Precompiled { BigModExp = 0x05, Bn254Add = 0x6, @@ -10,286 +6,70 @@ pub enum Precompiled { } #[derive(Clone, Debug)] -pub struct Code { - code: Vec, - constants: HashMap, - stack_len: usize, +pub struct YulCode { + // runtime code area + runtime: String, } -impl Code { - pub fn new(constants: impl IntoIterator) -> Self { - let mut code = Self { - code: Vec::new(), - constants: HashMap::new(), - stack_len: 0, - }; - let constants = constants.into_iter().collect_vec(); - for constant in constants.iter() { - code.push(*constant); - } - code.constants = HashMap::from_iter( - constants - .into_iter() - .enumerate() - .map(|(idx, value)| (value, idx)), - ); - code - } - - pub fn deployment(code: Vec) -> Vec { - let code_len = code.len(); - assert_ne!(code_len, 0); - - iter::empty() - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 14, - PUSH1, - 0, - CODECOPY, - ]) - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 0, - RETURN, - ]) - .chain(code) - .collect() - } - - pub fn stack_len(&self) -> usize { - self.stack_len - } - - pub fn len(&self) -> usize { - self.code.len() - } - - pub fn is_empty(&self) -> bool { - self.code.is_empty() - } - - pub fn push>(&mut self, value: T) -> &mut Self { - let value = value.into(); - match self.constants.get(&value) { - Some(idx) if (0..16).contains(&(self.stack_len - idx - 1)) => { - self.dup(self.stack_len - idx - 1); - } - _ => { - let mut bytes = vec![0; 32]; - value.to_big_endian(&mut bytes); - let bytes = bytes - .iter() - .position(|byte| *byte != 0) - .map_or(vec![0], |pos| bytes.drain(pos..).collect()); - self.code.push(PUSH1 - 1 + bytes.len() as u8); - self.code.extend(bytes); - self.stack_len += 1; - } +impl YulCode { + pub fn new() -> Self { + YulCode { + runtime: String::new(), } - self - } - - pub fn dup(&mut self, pos: usize) -> &mut Self { - assert!((0..16).contains(&pos)); - self.code.push(DUP1 + pos as u8); - self.stack_len += 1; - self } - pub fn swap(&mut self, pos: usize) -> &mut Self { - assert!((1..17).contains(&pos)); - self.code.push(SWAP1 - 1 + pos as u8); - self + pub fn code(&self, base_modulus: String, scalar_modulus: String) -> String { + format!( + " + object \"plonk_verifier\" {{ + code {{ + function allocate(size) -> ptr {{ + ptr := mload(0x40) + if eq(ptr, 0) {{ ptr := 0x60 }} + mstore(0x40, add(ptr, size)) + }} + let size := datasize(\"Runtime\") + let offset := allocate(size) + datacopy(offset, dataoffset(\"Runtime\"), size) + return(offset, size) + }} + object \"Runtime\" {{ + code {{ + let success:bool := true + let f_p := {base_modulus} + let f_q := {scalar_modulus} + function validate_ec_point(x, y) -> valid:bool {{ + {{ + let x_lt_p:bool := lt(x, {base_modulus}) + let y_lt_p:bool := lt(y, {base_modulus}) + valid := and(x_lt_p, y_lt_p) + }} + {{ + let x_is_zero:bool := eq(x, 0) + let y_is_zero:bool := eq(y, 0) + let x_or_y_is_zero:bool := or(x_is_zero, y_is_zero) + let x_and_y_is_not_zero:bool := not(x_or_y_is_zero) + valid := and(x_and_y_is_not_zero, valid) + }} + {{ + let y_square := mulmod(y, y, {base_modulus}) + let x_square := mulmod(x, x, {base_modulus}) + let x_cube := mulmod(x_square, x, {base_modulus}) + let x_cube_plus_3 := addmod(x_cube, 3, {base_modulus}) + let y_square_eq_x_cube_plus_3:bool := eq(x_cube_plus_3, y_square) + valid := and(y_square_eq_x_cube_plus_3, valid) + }} + }} + {} + }} + }} + }}", + self.runtime + ) } -} -impl From for Vec { - fn from(code: Code) -> Self { - code.code + pub fn runtime_append(&mut self, mut code: String) { + code.push('\n'); + self.runtime.push_str(&code); } } - -macro_rules! impl_opcodes { - ($($method:ident -> ($opcode:ident, $stack_len_diff:expr))*) => { - $( - #[allow(dead_code)] - impl Code { - pub fn $method(&mut self) -> &mut Self { - self.code.push($opcode); - self.stack_len = ((self.stack_len as isize) + $stack_len_diff) as usize; - self - } - } - )* - }; -} - -impl_opcodes!( - stop -> (STOP, 0) - add -> (ADD, -1) - mul -> (MUL, -1) - sub -> (SUB, -1) - div -> (DIV, -1) - sdiv -> (SDIV, -1) - r#mod -> (MOD, -1) - smod -> (SMOD, -1) - addmod -> (ADDMOD, -2) - mulmod -> (MULMOD, -2) - exp -> (EXP, -1) - signextend -> (SIGNEXTEND, -1) - lt -> (LT, -1) - gt -> (GT, -1) - slt -> (SLT, -1) - sgt -> (SGT, -1) - eq -> (EQ, -1) - iszero -> (ISZERO, 0) - and -> (AND, -1) - or -> (OR, -1) - xor -> (XOR, -1) - not -> (NOT, 0) - byte -> (BYTE, -1) - shl -> (SHL, -1) - shr -> (SHR, -1) - sar -> (SAR, -1) - keccak256 -> (SHA3, -1) - address -> (ADDRESS, 1) - balance -> (BALANCE, 0) - origin -> (ORIGIN, 1) - caller -> (CALLER, 1) - callvalue -> (CALLVALUE, 1) - calldataload -> (CALLDATALOAD, 0) - calldatasize -> (CALLDATASIZE, 1) - calldatacopy -> (CALLDATACOPY, -3) - codesize -> (CODESIZE, 1) - codecopy -> (CODECOPY, -3) - gasprice -> (GASPRICE, 1) - extcodesize -> (EXTCODESIZE, 0) - extcodecopy -> (EXTCODECOPY, -4) - returndatasize -> (RETURNDATASIZE, 1) - returndatacopy -> (RETURNDATACOPY, -3) - extcodehash -> (EXTCODEHASH, 0) - blockhash -> (BLOCKHASH, 0) - coinbase -> (COINBASE, 1) - timestamp -> (TIMESTAMP, 1) - number -> (NUMBER, 1) - difficulty -> (DIFFICULTY, 1) - gaslimit -> (GASLIMIT, 1) - chainid -> (CHAINID, 1) - selfbalance -> (SELFBALANCE, 1) - basefee -> (BASEFEE, 1) - pop -> (POP, -1) - mload -> (MLOAD, 0) - mstore -> (MSTORE, -2) - mstore8 -> (MSTORE8, -2) - sload -> (SLOAD, 0) - sstore -> (SSTORE, -2) - jump -> (JUMP, -1) - jumpi -> (JUMPI, -2) - pc -> (PC, 1) - msize -> (MSIZE, 1) - gas -> (GAS, 1) - jumpdest -> (JUMPDEST, 0) - log0 -> (LOG0, -2) - log1 -> (LOG1, -3) - log2 -> (LOG2, -4) - log3 -> (LOG3, -5) - log4 -> (LOG4, -6) - create -> (CREATE, -2) - call -> (CALL, -6) - callcode -> (CALLCODE, -6) - r#return -> (RETURN, -2) - delegatecall -> (DELEGATECALL, -5) - create2 -> (CREATE2, -3) - staticcall -> (STATICCALL, -5) - revert -> (REVERT, -2) - selfdestruct -> (SELFDESTRUCT, -1) -); - -const STOP: u8 = 0x00; -const ADD: u8 = 0x01; -const MUL: u8 = 0x02; -const SUB: u8 = 0x03; -const DIV: u8 = 0x04; -const SDIV: u8 = 0x05; -const MOD: u8 = 0x06; -const SMOD: u8 = 0x07; -const ADDMOD: u8 = 0x08; -const MULMOD: u8 = 0x09; -const EXP: u8 = 0x0A; -const SIGNEXTEND: u8 = 0x0B; -const LT: u8 = 0x10; -const GT: u8 = 0x11; -const SLT: u8 = 0x12; -const SGT: u8 = 0x13; -const EQ: u8 = 0x14; -const ISZERO: u8 = 0x15; -const AND: u8 = 0x16; -const OR: u8 = 0x17; -const XOR: u8 = 0x18; -const NOT: u8 = 0x19; -const BYTE: u8 = 0x1A; -const SHL: u8 = 0x1B; -const SHR: u8 = 0x1C; -const SAR: u8 = 0x1D; -const SHA3: u8 = 0x20; -const ADDRESS: u8 = 0x30; -const BALANCE: u8 = 0x31; -const ORIGIN: u8 = 0x32; -const CALLER: u8 = 0x33; -const CALLVALUE: u8 = 0x34; -const CALLDATALOAD: u8 = 0x35; -const CALLDATASIZE: u8 = 0x36; -const CALLDATACOPY: u8 = 0x37; -const CODESIZE: u8 = 0x38; -const CODECOPY: u8 = 0x39; -const GASPRICE: u8 = 0x3A; -const EXTCODESIZE: u8 = 0x3B; -const EXTCODECOPY: u8 = 0x3C; -const RETURNDATASIZE: u8 = 0x3D; -const RETURNDATACOPY: u8 = 0x3E; -const EXTCODEHASH: u8 = 0x3F; -const BLOCKHASH: u8 = 0x40; -const COINBASE: u8 = 0x41; -const TIMESTAMP: u8 = 0x42; -const NUMBER: u8 = 0x43; -const DIFFICULTY: u8 = 0x44; -const GASLIMIT: u8 = 0x45; -const CHAINID: u8 = 0x46; -const SELFBALANCE: u8 = 0x47; -const BASEFEE: u8 = 0x48; -const POP: u8 = 0x50; -const MLOAD: u8 = 0x51; -const MSTORE: u8 = 0x52; -const MSTORE8: u8 = 0x53; -const SLOAD: u8 = 0x54; -const SSTORE: u8 = 0x55; -const JUMP: u8 = 0x56; -const JUMPI: u8 = 0x57; -const PC: u8 = 0x58; -const MSIZE: u8 = 0x59; -const GAS: u8 = 0x5A; -const JUMPDEST: u8 = 0x5B; -const PUSH1: u8 = 0x60; -const DUP1: u8 = 0x80; -const SWAP1: u8 = 0x90; -const LOG0: u8 = 0xA0; -const LOG1: u8 = 0xA1; -const LOG2: u8 = 0xA2; -const LOG3: u8 = 0xA3; -const LOG4: u8 = 0xA4; -const CREATE: u8 = 0xF0; -const CALL: u8 = 0xF1; -const CALLCODE: u8 = 0xF2; -const RETURN: u8 = 0xF3; -const DELEGATECALL: u8 = 0xF4; -const CREATE2: u8 = 0xF5; -const STATICCALL: u8 = 0xFA; -const REVERT: u8 = 0xFD; -const SELFDESTRUCT: u8 = 0xFF; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 9e8b8e2a..c630a468 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -1,9 +1,11 @@ use crate::{ - loader::evm::{ - code::{Code, Precompiled}, - fe_to_u256, modulus, + loader::{ + evm::{ + code::{Precompiled, YulCode}, + fe_to_u256, modulus, u256_to_fe, + }, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, }, - loader::{evm::u256_to_fe, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, util::{ arithmetic::{CurveAffine, FieldOps, PrimeField}, Itertools, @@ -11,6 +13,7 @@ use crate::{ Error, }; use ethereum_types::{U256, U512}; +use hex; use std::{ cell::RefCell, collections::HashMap, @@ -50,24 +53,29 @@ impl Value { pub struct EvmLoader { base_modulus: U256, scalar_modulus: U256, - code: RefCell, + code: RefCell, ptr: RefCell, cache: RefCell>, #[cfg(test)] gas_metering_ids: RefCell>, } +fn hex_encode_u256(value: &U256) -> String { + let mut bytes = [0; 32]; + value.to_big_endian(&mut bytes); + format!("0x{}", hex::encode(bytes)) +} + impl EvmLoader { pub fn new() -> Rc where - Base: PrimeField, + Base: PrimeField, Scalar: PrimeField, { let base_modulus = modulus::(); let scalar_modulus = modulus::(); - let code = Code::new([1.into(), base_modulus, scalar_modulus - 1, scalar_modulus]) - .push(1) - .to_owned(); + let code = YulCode::new(); + Rc::new(Self { base_modulus, scalar_modulus, @@ -79,22 +87,16 @@ impl EvmLoader { }) } - pub fn deployment_code(self: &Rc) -> Vec { - Code::deployment(self.runtime_code()) - } - - pub fn runtime_code(self: &Rc) -> Vec { - let mut code = self.code.borrow().clone(); - let dst = code.len() + 9; - code.push(dst) - .jumpi() - .push(0) - .push(0) - .revert() - .jumpdest() - .stop() - .to_owned() - .into() + pub fn yul_code(self: &Rc) -> String { + let code = " + if not(success) { revert(0, 0) } + return(0, 0)" + .to_string(); + self.code.borrow_mut().runtime_append(code); + self.code.borrow().code( + hex_encode_u256(&self.base_modulus), + hex_encode_u256(&self.scalar_modulus), + ) } pub fn allocate(self: &Rc, size: usize) -> usize { @@ -103,121 +105,64 @@ impl EvmLoader { ptr } - pub(crate) fn scalar_modulus(&self) -> U256 { - self.scalar_modulus - } - pub(crate) fn ptr(&self) -> usize { *self.ptr.borrow() } - pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { + pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { self.code.borrow_mut() } - pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { - let value = if matches!( - value, - Value::Constant(_) | Value::Memory(_) | Value::Negated(_) - ) { - value - } else { - let identifier = value.identifier(); - let some_ptr = self.cache.borrow().get(&identifier).cloned(); - let ptr = if let Some(ptr) = some_ptr { - ptr - } else { - self.push(&Scalar { - loader: self.clone(), - value, - }); - let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); - self.cache.borrow_mut().insert(identifier, ptr); - ptr - }; - Value::Memory(ptr) - }; - Scalar { - loader: self.clone(), - value, - } - } - - fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { - EcPoint { - loader: self.clone(), - value, - } - } - - fn push(self: &Rc, scalar: &Scalar) { + fn push(self: &Rc, scalar: &Scalar) -> String { match scalar.value.clone() { Value::Constant(constant) => { - self.code.borrow_mut().push(constant); + format!("{constant}") } Value::Memory(ptr) => { - self.code.borrow_mut().push(ptr).mload(); + format!("mload({ptr:#x})") } Value::Negated(value) => { - self.push(&self.scalar(*value)); - self.code.borrow_mut().push(self.scalar_modulus).sub(); + let v = self.push(&self.scalar(*value)); + format!("sub(f_q, {v})") } Value::Sum(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().addmod(); + let lhs = self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("addmod({lhs}, {rhs}, f_q)") } Value::Product(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().mulmod(); + let lhs = self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("mulmod({lhs}, {rhs}, f_q)") } } } pub fn calldataload_scalar(self: &Rc, offset: usize) -> Scalar { let ptr = self.allocate(0x20); - self.code - .borrow_mut() - .push(self.scalar_modulus) - .push(offset) - .calldataload() - .r#mod() - .push(ptr) - .mstore(); + let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))"); + self.code.borrow_mut().runtime_append(code); self.scalar(Value::Memory(ptr)) } pub fn calldataload_ec_point(self: &Rc, offset: usize) -> EcPoint { - let ptr = self.allocate(0x40); - self.code - .borrow_mut() - // [..., success] - .push(offset) - // [..., success, x_cd_ptr] - .calldataload() - // [..., success, x] - .dup(0) - // [..., success, x, x] - .push(ptr) - // [..., success, x, x, x_ptr] - .mstore() - // [..., success, x] - .push(offset + 0x20) - // [..., success, x, y_cd_ptr] - .calldataload() - // [..., success, x, y] - .dup(0) - // [..., success, x, y, y] - .push(ptr + 0x20) - // [..., success, x, y, y, y_ptr] - .mstore(); - // [..., success, x, y] - self.validate_ec_point(); - self.ec_point(Value::Memory(ptr)) + let x_ptr = self.allocate(0x40); + let y_ptr = x_ptr + 0x20; + let x_cd_ptr = offset; + let y_cd_ptr = offset + 0x20; + let validate_code = self.validate_ec_point(); + let code = format!( + " + {{ + let x := calldataload({x_cd_ptr:#x}) + mstore({x_ptr:#x}, x) + let y := calldataload({y_cd_ptr:#x}) + mstore({y_ptr:#x}, y) + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); + self.ec_point(Value::Memory(x_ptr)) } pub fn ec_point_from_limbs( @@ -226,124 +171,94 @@ impl EvmLoader { y_limbs: [&Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); - for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { - for (idx, limb) in limbs.into_iter().enumerate() { - self.push(limb); - // [..., success, acc] - if idx > 0 { - self.code - .borrow_mut() - .push(idx * BITS) - // [..., success, acc, limb_i, shift] - .shl() - // [..., success, acc, limb_i << shift] - .add(); - // [..., success, acc] - } + let mut code = String::new(); + for (idx, limb) in x_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let x := {limb_i}\n").as_str()); + } else { + code.push_str(format!("x := add(x, shl({shift}, {limb_i}))\n").as_str()); + } + } + let x_ptr = ptr; + code.push_str(format!("mstore({x_ptr}, x)\n").as_str()); + for (idx, limb) in y_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let y := {limb_i}\n").as_str()); + } else { + code.push_str(format!("y := add(y, shl({shift}, {limb_i}))\n").as_str()); } - self.code - .borrow_mut() - // [..., success, coordinate] - .dup(0) - // [..., success, coordinate, coordinate] - .push(ptr) - // [..., success, coordinate, coordinate, ptr] - .mstore(); - // [..., success, coordinate] } - // [..., success, x, y] - self.validate_ec_point(); + let y_ptr = ptr + 0x20; + code.push_str(format!("mstore({y_ptr}, y)\n").as_str()); + let validate_code = self.validate_ec_point(); + let code = format!( + "{{ + {code} + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); self.ec_point(Value::Memory(ptr)) } - fn validate_ec_point(self: &Rc) { - self.code - .borrow_mut() - // [..., success, x, y] - .push(self.base_modulus) - // [..., success, x, y, p] - .dup(2) - // [..., success, x, y, p, x] - .lt() - // [..., success, x, y, x_lt_p] - .push(self.base_modulus) - // [..., success, x, y, x_lt_p, p] - .dup(2) - // [..., success, x, y, x_lt_p, p, y] - .lt() - // [..., success, x, y, x_lt_p, y_lt_p] - .and() - // [..., success, x, y, valid] - .dup(2) - // [..., success, x, y, valid, x] - .iszero() - // [..., success, x, y, valid, x_is_zero] - .dup(2) - // [..., success, x, y, valid, x_is_zero, y] - .iszero() - // [..., success, x, y, valid, x_is_zero, y_is_zero] - .or() - // [..., success, x, y, valid, x_or_y_is_zero] - .not() - // [..., success, x, y, valid, x_and_y_is_not_zero] - .and() - // [..., success, x, y, valid] - .push(self.base_modulus) - // [..., success, x, y, valid, p] - .dup(2) - // [..., success, x, y, valid, p, y] - .dup(0) - // [..., success, x, y, valid, p, y, y] - .mulmod() - // [..., success, x, y, valid, y_square] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p] - .push(3) - // [..., success, x, y, valid, y_square, p, 3] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p] - .dup(6) - // [..., success, x, y, valid, y_square, p, 3, p, x] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p, x, p] - .dup(1) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x] - .dup(0) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x, x] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, p, x, x_square] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, x_cube] - .addmod() - // [..., success, x, y, valid, y_square, x_cube_plus_3] - .eq() - // [..., success, x, y, valid, y_square_eq_x_cube_plus_3] - .and() - // [..., success, x, y, valid] - .swap(2) - // [..., success, valid, y, x] - .pop() - // [..., success, valid, y] - .pop() - // [..., success, valid] - .and(); + fn validate_ec_point(self: &Rc) -> String { + "success := and(validate_ec_point(x, y), success)".to_string() + } + + pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { + let value = if matches!( + value, + Value::Constant(_) | Value::Memory(_) | Value::Negated(_) + ) { + value + } else { + let identifier = value.identifier(); + let some_ptr = self.cache.borrow().get(&identifier).cloned(); + let ptr = if let Some(ptr) = some_ptr { + ptr + } else { + let v = self.push(&Scalar { + loader: self.clone(), + value, + }); + let ptr = self.allocate(0x20); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {v})")); + self.cache.borrow_mut().insert(identifier, ptr); + ptr + }; + Value::Memory(ptr) + }; + Scalar { + loader: self.clone(), + value, + } + } + + fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { + EcPoint { + loader: self.clone(), + value, + } } pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { let hash_ptr = self.allocate(0x20); - self.code - .borrow_mut() - .push(len) - .push(ptr) - .keccak256() - .push(hash_ptr) - .mstore(); + let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))"); + self.code.borrow_mut().runtime_append(code); hash_ptr } pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { - self.push(scalar); - self.code.borrow_mut().push(ptr).mstore(); + let scalar = self.push(scalar); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {scalar})")); } pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { @@ -356,26 +271,26 @@ impl EvmLoader { let ptr = self.allocate(0x40); match value.value { Value::Constant((x, y)) => { - self.code - .borrow_mut() - .push(x) - .push(ptr) - .mstore() - .push(y) - .push(ptr + 0x20) - .mstore(); + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let x = hex_encode_u256(&x); + let y = hex_encode_u256(&y); + let code = format!( + "mstore({x_ptr:#x}, {x}) + mstore({y_ptr:#x}, {y})" + ); + self.code.borrow_mut().runtime_append(code); } Value::Memory(src_ptr) => { - self.code - .borrow_mut() - .push(src_ptr) - .mload() - .push(ptr) - .mstore() - .push(src_ptr + 0x20) - .mload() - .push(ptr + 0x20) - .mstore(); + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let src_x = src_ptr; + let src_y = src_ptr + 0x20; + let code = format!( + "mstore({x_ptr:#x}, mload({src_x:#x})) + mstore({y_ptr:#x}, mload({src_y:#x}))" + ); + self.code.borrow_mut().runtime_append(code); } Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { unreachable!() @@ -391,16 +306,9 @@ impl EvmLoader { Precompiled::Bn254ScalarMul => (0x60, 0x40), Precompiled::Bn254Pairing => (0x180, 0x20), }; - self.code - .borrow_mut() - .push(rd_len) - .push(rd_ptr) - .push(cd_len) - .push(cd_ptr) - .push(precompile as usize) - .gas() - .staticcall() - .and(); + let a = precompile as usize; + let code = format!("success := and(eq(staticcall(gas(), {a:#x}, {cd_ptr:#x}, {cd_len:#x}, {rd_ptr:#x}, {rd_len:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); } fn invert(self: &Rc, scalar: &Scalar) -> Scalar { @@ -441,38 +349,41 @@ impl EvmLoader { ) { let rd_ptr = self.dup_ec_point(lhs).ptr(); self.allocate(0x80); - self.code - .borrow_mut() - .push(g2.0) - .push(rd_ptr + 0x40) - .mstore() - .push(g2.1) - .push(rd_ptr + 0x60) - .mstore() - .push(g2.2) - .push(rd_ptr + 0x80) - .mstore() - .push(g2.3) - .push(rd_ptr + 0xa0) - .mstore(); + let g2_0 = hex_encode_u256(&g2.0); + let g2_0_ptr = rd_ptr + 0x40; + let g2_1 = hex_encode_u256(&g2.1); + let g2_1_ptr = rd_ptr + 0x60; + let g2_2 = hex_encode_u256(&g2.2); + let g2_2_ptr = rd_ptr + 0x80; + let g2_3 = hex_encode_u256(&g2.3); + let g2_3_ptr = rd_ptr + 0xa0; + let code = format!( + "mstore({g2_0_ptr:#x}, {g2_0}) + mstore({g2_1_ptr:#x}, {g2_1}) + mstore({g2_2_ptr:#x}, {g2_2}) + mstore({g2_3_ptr:#x}, {g2_3})" + ); + self.code.borrow_mut().runtime_append(code); self.dup_ec_point(rhs); self.allocate(0x80); - self.code - .borrow_mut() - .push(minus_s_g2.0) - .push(rd_ptr + 0x100) - .mstore() - .push(minus_s_g2.1) - .push(rd_ptr + 0x120) - .mstore() - .push(minus_s_g2.2) - .push(rd_ptr + 0x140) - .mstore() - .push(minus_s_g2.3) - .push(rd_ptr + 0x160) - .mstore(); + let minus_s_g2_0 = hex_encode_u256(&minus_s_g2.0); + let minus_s_g2_0_ptr = rd_ptr + 0x100; + let minus_s_g2_1 = hex_encode_u256(&minus_s_g2.1); + let minus_s_g2_1_ptr = rd_ptr + 0x120; + let minus_s_g2_2 = hex_encode_u256(&minus_s_g2.2); + let minus_s_g2_2_ptr = rd_ptr + 0x140; + let minus_s_g2_3 = hex_encode_u256(&minus_s_g2.3); + let minus_s_g2_3_ptr = rd_ptr + 0x160; + let code = format!( + "mstore({minus_s_g2_0_ptr:#x}, {minus_s_g2_0}) + mstore({minus_s_g2_1_ptr:#x}, {minus_s_g2_1}) + mstore({minus_s_g2_2_ptr:#x}, {minus_s_g2_2}) + mstore({minus_s_g2_3_ptr:#x}, {minus_s_g2_3})" + ); + self.code.borrow_mut().runtime_append(code); self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr); - self.code.borrow_mut().push(rd_ptr).mload().and(); + let code = format!("success := and(eq(mload({rd_ptr:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); } fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { @@ -525,21 +436,16 @@ impl EvmLoader { self.gas_metering_ids .borrow_mut() .push(identifier.to_string()); - self.code.borrow_mut().gas().swap(1); + let code = format!("let {identifier} := gas()"); + self.code.borrow_mut().runtime_append(code); } fn end_gas_metering(self: &Rc) { - self.code - .borrow_mut() - .swap(1) - .push(9) - .gas() - .swap(2) - .sub() - .sub() - .push(0) - .push(0) - .log1(); + let code = format!( + "log1(0, 0, sub({}, gas()))", + self.gas_metering_ids.borrow().last().unwrap() + ); + self.code.borrow_mut().runtime_append(code); } pub fn print_gas_metering(self: &Rc, costs: Vec) { @@ -745,7 +651,7 @@ impl PartialEq for Scalar { impl> LoadedScalar for Scalar { type Loader = Rc; - fn loader(&self) -> &Rc { + fn loader(&self) -> &Self::Loader { &self.loader } } @@ -753,7 +659,7 @@ impl> LoadedScalar for Scalar { impl EcPointLoader for Rc where C: CurveAffine, - C::ScalarExt: PrimeField, + C::Scalar: PrimeField, { type LoadedEcPoint = EcPoint; @@ -802,48 +708,39 @@ impl> ScalarLoader for Rc { let push_addend = |(coeff, value): &(F, &Scalar)| { assert_ne!(*coeff, F::zero()); match (*coeff == F::one(), &value.value) { - (true, _) => { - self.push(value); - } - (false, Value::Constant(value)) => { - self.push(&self.scalar(Value::Constant(fe_to_u256( - *coeff * u256_to_fe::(*value), - )))); - } + (true, _) => self.push(value), + (false, Value::Constant(value)) => self.push(&self.scalar(Value::Constant( + fe_to_u256(*coeff * u256_to_fe::(*value)), + ))), (false, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(value); - self.code.borrow_mut().mulmod(); + let value = self.push(value); + let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({value}, {coeff}, f_q)") } } }; let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); } let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); self.scalar(Value::Memory(ptr)) } @@ -863,63 +760,64 @@ impl> ScalarLoader for Rc { (_, Value::Constant(lhs), Value::Constant(rhs)) => { self.push(&self.scalar(Value::Constant(fe_to_u256( *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), - )))); + )))) } (_, value @ Value::Memory(_), Value::Constant(constant)) | (_, Value::Constant(constant), value @ Value::Memory(_)) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256( + let v1 = self.push(&self.scalar(value.clone())); + let v2 = self.push(&self.scalar(Value::Constant(fe_to_u256( *coeff * u256_to_fe::(*constant), )))); - self.push(&self.scalar(value.clone())); - self.code.borrow_mut().mulmod(); + format!("mulmod({v1}, {v2}, f_q)") } (true, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(lhs); - self.push(rhs); - self.code.borrow_mut().mulmod(); + let rhs = self.push(rhs); + let lhs = self.push(lhs); + format!("mulmod({rhs}, {lhs}, f_q)") } (false, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus).dup(0); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(lhs); - self.code.borrow_mut().mulmod(); - self.push(rhs); - self.code.borrow_mut().mulmod(); + let rhs = self.push(rhs); + let lhs = self.push(lhs); + let value = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({rhs}, mulmod({lhs}, {value}, f_q), f_q)") } } }; let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); } let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); self.scalar(Value::Memory(ptr)) } + // batch_invert algorithm + // n := values.len() - 1 + // input : values[0], ..., values[n] + // output : values[0]^{-1}, ..., values[n]^{-1} + // 1. products[i] <- values[0] * ... * values[i], i = 1, ..., n + // 2. inv <- (products[n])^{-1} + // 3. v_n <- values[n] + // 4. values[n] <- products[n - 1] * inv (values[n]^{-1}) + // 5. inv <- v_n * inv fn batch_invert<'a>(values: impl IntoIterator) { let values = values.into_iter().collect_vec(); let loader = &values.first().unwrap().loader; @@ -931,29 +829,35 @@ impl> ScalarLoader for Rc { ) .collect_vec(); - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); + let initial_value = loader.push(products.first().unwrap()); + let mut code = format!("let prod := {initial_value}\n"); + for (_, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { + let v = loader.push(value); + let ptr = product.ptr(); + code.push_str( + format!( + " + prod := mulmod({v}, prod, f_q) + mstore({ptr:#x}, prod) + " + ) + .as_str(), + ); } - - loader.push(products.first().unwrap()); - for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { - loader.push(value); - loader.code.borrow_mut().mulmod(); - if idx < values.len() - 2 { - loader.code.borrow_mut().dup(0); - } - loader.code.borrow_mut().push(product.ptr()).mstore(); - } - - let inv = loader.invert(products.last().unwrap()); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(&inv); + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + + let inv = loader.push(&loader.invert(products.last().unwrap())); + + let mut code = format!( + " + let inv := {inv} + let v + " + ); for (value, product) in values.iter().rev().zip( products .iter() @@ -963,22 +867,29 @@ impl> ScalarLoader for Rc { .chain(iter::once(None)), ) { if let Some(product) = product { - loader.push(value); - loader - .code - .borrow_mut() - .dup(2) - .dup(2) - .push(product.ptr()) - .mload() - .mulmod() - .push(value.ptr()) - .mstore() - .mulmod(); + let val_ptr = value.ptr(); + let prod_ptr = product.ptr(); + let v = loader.push(value); + code.push_str( + format!( + " + v := {v} + mstore({val_ptr}, mulmod(mload({prod_ptr:#x}), inv, f_q)) + inv := mulmod(v, inv, f_q) + " + ) + .as_str(), + ); } else { - loader.code.borrow_mut().push(value.ptr()).mstore(); + let ptr = value.ptr(); + code.push_str(format!("mstore({ptr:#x}, inv)\n").as_str()); } } + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); } } diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index f81c5f54..e6f3703e 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -3,7 +3,6 @@ use crate::{ util::Itertools, }; use ethereum_types::{Address, U256}; -use revm::{AccountInfo, Bytecode}; use std::env::var_os; mod tui; @@ -15,28 +14,26 @@ fn debug() -> bool { ) } -pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { +pub fn execute(deployment_code: Vec, calldata: Vec) -> (bool, u64, Vec) { assert!( - code.len() <= 0x6000, + deployment_code.len() <= 0x6000, "Contract size {} exceeds the limit 24576", - code.len() + deployment_code.len() ); let debug = debug(); let caller = Address::from_low_u64_be(0xfe); - let callee = Address::from_low_u64_be(0xff); let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) .set_debugger(debug) .build(); - evm.db_mut().insert_account_info( - callee, - AccountInfo::new(0.into(), 1, Bytecode::new_raw(code.into())), - ); - - let result = evm.call_raw(caller, callee, calldata.into(), 0.into()); + let contract = evm + .deploy(caller, deployment_code.into(), 0.into()) + .address + .unwrap(); + let result = evm.call_raw(caller, contract, calldata.into(), 0.into()); let costs = result .logs diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs index 0a772513..a7df5209 100644 --- a/src/loader/evm/util.rs +++ b/src/loader/evm/util.rs @@ -3,7 +3,11 @@ use crate::{ util::{arithmetic::PrimeField, Itertools}, }; use ethereum_types::U256; -use std::iter; +use std::{ + io::Write, + iter, + process::{Command, Stdio}, +}; pub(crate) mod executor; @@ -94,3 +98,37 @@ pub fn estimate_gas(cost: Cost) -> usize { intrinsic_cost + calldata_cost + ec_operation_cost } + +pub fn compile_yul(code: &str) -> Vec { + let mut cmd = Command::new("solc") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .arg("--bin") + .arg("--yul") + .arg("-") + .spawn() + .unwrap(); + cmd.stdin + .take() + .unwrap() + .write_all(code.as_bytes()) + .unwrap(); + let output = cmd.wait_with_output().unwrap().stdout; + let binary = *split_by_ascii_whitespace(&output).last().unwrap(); + hex::decode(binary).unwrap() +} + +fn split_by_ascii_whitespace(bytes: &[u8]) -> Vec<&[u8]> { + let mut split = Vec::new(); + let mut start = None; + for (idx, byte) in bytes.iter().enumerate() { + if byte.is_ascii_whitespace() { + if let Some(start) = start.take() { + split.push(&bytes[start..idx]); + } + } else if start.is_none() { + start = Some(idx); + } + } + split +} diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs index de3e2a06..3a3ba096 100644 --- a/src/pcs/kzg/decider.rs +++ b/src/pcs/kzg/decider.rs @@ -132,14 +132,8 @@ mod evm { let hash_ptr = loader.keccak256(lhs[0].ptr(), lhs.len() * 0x80); let challenge_ptr = loader.allocate(0x20); - loader - .code_mut() - .push(loader.scalar_modulus()) - .push(hash_ptr) - .mload() - .r#mod() - .push(challenge_ptr) - .mstore(); + let code = format!("mstore({challenge_ptr}, mod(mload({hash_ptr}), f_q))"); + loader.code_mut().runtime_append(code); let challenge = loader.scalar(Value::Memory(challenge_ptr)); let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs index 4e57d369..1ab52083 100644 --- a/src/system/halo2/test/kzg/evm.rs +++ b/src/system/halo2/test/kzg/evm.rs @@ -24,7 +24,7 @@ macro_rules! halo2_kzg_evm_verify { use halo2_proofs::poly::commitment::ParamsProver; use std::rc::Rc; use $crate::{ - loader::evm::{encode_calldata, execute, EvmLoader}, + loader::evm::{compile_yul, encode_calldata, execute, EvmLoader}, system::halo2::{ test::kzg::{BITS, LIMBS}, transcript::evm::EvmTranscript, @@ -34,7 +34,7 @@ macro_rules! halo2_kzg_evm_verify { }; let loader = EvmLoader::new::(); - let runtime_code = { + let deployment_code = { let svk = $params.get_g()[0].into(); let dk = ($params.g2(), $params.s_g2()).into(); let protocol = $protocol.loaded(&loader); @@ -49,11 +49,11 @@ macro_rules! halo2_kzg_evm_verify { .unwrap(); <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.runtime_code() + compile_yul(&loader.yul_code()) }; let (accept, total_cost, costs) = - execute(runtime_code, encode_calldata($instances, &$proof)); + execute(deployment_code, encode_calldata($instances, &$proof)); loader.print_gas_metering(costs); println!("Total gas cost: {}", total_cost); diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs index e461bc05..32d60b8f 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/src/system/halo2/transcript/evm.rs @@ -73,11 +73,9 @@ where fn squeeze_challenge(&mut self) -> Scalar { let len = if self.buf.len() == 0x20 { assert_eq!(self.loader.ptr(), self.buf.end()); - self.loader - .code_mut() - .push(1) - .push(self.buf.end()) - .mstore8(); + let buf_end = self.buf.end(); + let code = format!("mstore8({buf_end}, 1)"); + self.loader.code_mut().runtime_append(code); 0x21 } else { self.buf.len() @@ -86,17 +84,14 @@ where let challenge_ptr = self.loader.allocate(0x20); let dup_hash_ptr = self.loader.allocate(0x20); - self.loader - .code_mut() - .push(hash_ptr) - .mload() - .push(self.loader.scalar_modulus()) - .dup(1) - .r#mod() - .push(challenge_ptr) - .mstore() - .push(dup_hash_ptr) - .mstore(); + let code = format!( + "{{ + let hash := mload({hash_ptr:#x}) + mstore({challenge_ptr:#x}, mod(hash, f_q)) + mstore({dup_hash_ptr:#x}, hash) + }}" + ); + self.loader.code_mut().runtime_append(code); self.buf.reset(dup_hash_ptr); self.buf.extend(0x20);