diff --git a/zkevm-circuits/src/evm_circuit.rs b/zkevm-circuits/src/evm_circuit.rs index c88a1d4ca4..0ae23cca81 100644 --- a/zkevm-circuits/src/evm_circuit.rs +++ b/zkevm-circuits/src/evm_circuit.rs @@ -432,6 +432,8 @@ pub(crate) mod test { FixedTableTag::Range512, FixedTableTag::SignByte, FixedTableTag::ResponsibleOpcode, + FixedTableTag::Bitslevel, + FixedTableTag::Pow64, ], ) } diff --git a/zkevm-circuits/src/evm_circuit/execution.rs b/zkevm-circuits/src/evm_circuit/execution.rs index 4db5031865..eb22e0be22 100644 --- a/zkevm-circuits/src/evm_circuit/execution.rs +++ b/zkevm-circuits/src/evm_circuit/execution.rs @@ -34,6 +34,7 @@ mod mul; mod pc; mod pop; mod push; +mod shl; mod signed_comparator; mod signextend; mod stop; @@ -58,6 +59,7 @@ use mul::MulGadget; use pc::PcGadget; use pop::PopGadget; use push::PushGadget; +use shl::ShlGadget; use signed_comparator::SignedComparatorGadget; use signextend::SignextendGadget; use stop::StopGadget; @@ -104,6 +106,7 @@ pub(crate) struct ExecutionConfig { pc_gadget: PcGadget, pop_gadget: PopGadget, push_gadget: PushGadget, + shl_gadget: ShlGadget, signed_comparator_gadget: SignedComparatorGadget, signextend_gadget: SignextendGadget, stop_gadget: StopGadget, @@ -232,6 +235,7 @@ impl ExecutionConfig { pc_gadget: configure_gadget!(), pop_gadget: configure_gadget!(), push_gadget: configure_gadget!(), + shl_gadget: configure_gadget!(), signed_comparator_gadget: configure_gadget!(), signextend_gadget: configure_gadget!(), stop_gadget: configure_gadget!(), @@ -464,6 +468,7 @@ impl ExecutionConfig { ExecutionState::ADD => assign_exec_step!(self.add_gadget), ExecutionState::MUL => assign_exec_step!(self.mul_gadget), ExecutionState::BITWISE => assign_exec_step!(self.bitwise_gadget), + ExecutionState::SHL => assign_exec_step!(self.shl_gadget), ExecutionState::SIGNEXTEND => { assign_exec_step!(self.signextend_gadget) } diff --git a/zkevm-circuits/src/evm_circuit/execution/shl.rs b/zkevm-circuits/src/evm_circuit/execution/shl.rs new file mode 100644 index 0000000000..27c9082a77 --- /dev/null +++ b/zkevm-circuits/src/evm_circuit/execution/shl.rs @@ -0,0 +1,114 @@ +use crate::{ + evm_circuit::{ + execution::ExecutionGadget, + step::ExecutionState, + util::{ + common_gadget::SameContextGadget, + constraint_builder::{ConstraintBuilder, StepStateTransition, Transition::Delta}, + math_gadget::ShlWordsGadget, + }, + witness::{Block, Call, ExecStep, Transaction}, + }, + util::Expr, +}; +use halo2::{arithmetic::FieldExt, circuit::Region, plonk::Error}; + +#[derive(Clone, Debug)] +pub(crate) struct ShlGadget { + same_context: SameContextGadget, + shl_words: ShlWordsGadget, +} + +impl ExecutionGadget for ShlGadget { + const NAME: &'static str = "SHL"; + + const EXECUTION_STATE: ExecutionState = ExecutionState::SHL; + + fn configure(cb: &mut ConstraintBuilder) -> Self { + let opcode = cb.query_cell(); + + let a = cb.query_word(); + let shift = cb.query_word(); + + cb.stack_pop(shift.expr()); + cb.stack_pop(a.expr()); + let shl_words = ShlWordsGadget::construct(cb, a, shift); + cb.stack_push(shl_words.b().expr()); + + let step_state_transition = StepStateTransition { + rw_counter: Delta(3.expr()), + program_counter: Delta(1.expr()), + stack_pointer: Delta(1.expr()), + ..Default::default() + }; + let same_context = SameContextGadget::construct(cb, opcode, step_state_transition, None); + + Self { + same_context, + shl_words, + } + } + + fn assign_exec_step( + &self, + region: &mut Region<'_, F>, + offset: usize, + block: &Block, + _: &Transaction, + _: &Call, + step: &ExecStep, + ) -> Result<(), Error> { + self.same_context.assign_exec_step(region, offset, step)?; + let indices = [step.rw_indices[0], step.rw_indices[1], step.rw_indices[2]]; + let [shift, a, b] = indices.map(|idx| block.rws[idx].stack_value()); + self.shl_words.assign(region, offset, a, shift, b) + } +} + +#[cfg(test)] +mod test { + use crate::evm_circuit::test::rand_word; + use crate::test_util::run_test_circuits; + use eth_types::evm_types::OpcodeId; + use eth_types::{bytecode, Word}; + use rand::Rng; + + fn test_ok(opcode: OpcodeId, a: Word, shift: Word) { + let bytecode = bytecode! { + PUSH32(a) + PUSH32(shift) + #[start] + .write_op(opcode) + STOP + }; + assert_eq!(run_test_circuits(bytecode), Ok(())); + } + + #[test] + fn shl_gadget_simple() { + test_ok(OpcodeId::SHL, 0x02FF.into(), 0x1.into()); + } + + #[test] + fn shl_gadget_rand_normal_shift() { + let a = rand_word(); + let mut rng = rand::thread_rng(); + let shift = rng.gen_range(0..=255); + test_ok(OpcodeId::SHL, a, shift.into()); + } + + #[test] + fn shl_gadget_rand_overflow_shift() { + let a = rand_word(); + let shift = Word::from_big_endian(&[255u8; 32]); + test_ok(OpcodeId::SHL, a, shift); + } + + //this testcase manage to check the split is correct. + #[test] + fn shl_gadget_constant_shift() { + let a = rand_word(); + test_ok(OpcodeId::SHL, a, 8.into()); + test_ok(OpcodeId::SHL, a, 64.into()); + } +} diff --git a/zkevm-circuits/src/evm_circuit/table.rs b/zkevm-circuits/src/evm_circuit/table.rs index e5fc77349c..a044bdfe50 100644 --- a/zkevm-circuits/src/evm_circuit/table.rs +++ b/zkevm-circuits/src/evm_circuit/table.rs @@ -32,6 +32,8 @@ pub enum FixedTableTag { BitwiseOr, BitwiseXor, ResponsibleOpcode, + Bitslevel, + Pow64, } impl FixedTableTag { @@ -46,6 +48,8 @@ impl FixedTableTag { Self::BitwiseOr, Self::BitwiseXor, Self::ResponsibleOpcode, + Self::Bitslevel, + Self::Pow64, ] .iter() .copied() @@ -98,6 +102,17 @@ impl FixedTableTag { }) })) } + Self::Bitslevel => Box::new((0..9).flat_map(move |level| { + (0..(1 << level)).map(move |idx| [tag, F::from(level), F::from(idx), F::zero()]) + })), + Self::Pow64 => Box::new((0..64).map(move |idx| { + [ + tag, + F::from(idx), + F::from_u128(1u128 << idx), + F::from_u128(1u128 << (64 - idx)), + ] + })), } } } diff --git a/zkevm-circuits/src/evm_circuit/util/constraint_builder.rs b/zkevm-circuits/src/evm_circuit/util/constraint_builder.rs index 8c72f21b72..94f5161c91 100644 --- a/zkevm-circuits/src/evm_circuit/util/constraint_builder.rs +++ b/zkevm-circuits/src/evm_circuit/util/constraint_builder.rs @@ -13,7 +13,7 @@ use halo2::{arithmetic::FieldExt, plonk::Expression}; use std::convert::TryInto; // Max degree allowed in all expressions passing through the ConstraintBuilder. -const MAX_DEGREE: usize = 2usize.pow(3) + 1; +const MAX_DEGREE: usize = 2usize.pow(3) + 3; // Degree added for expressions used in lookups. const LOOKUP_DEGREE: usize = 3; diff --git a/zkevm-circuits/src/evm_circuit/util/math_gadget.rs b/zkevm-circuits/src/evm_circuit/util/math_gadget.rs index 0fbf035a55..4c94b44f2e 100644 --- a/zkevm-circuits/src/evm_circuit/util/math_gadget.rs +++ b/zkevm-circuits/src/evm_circuit/util/math_gadget.rs @@ -1,7 +1,10 @@ use crate::{ - evm_circuit::util::{ - self, constraint_builder::ConstraintBuilder, from_bytes, pow_of_two, pow_of_two_expr, - select, split_u256, sum, Cell, + evm_circuit::{ + table::{FixedTableTag, Lookup}, + util::{ + self, constraint_builder::ConstraintBuilder, from_bytes, pow_of_two, pow_of_two_expr, + select, split_u256, sum, Cell, + }, }, util::Expr, }; @@ -783,3 +786,362 @@ impl MinMaxGadget { }) } } + +// This function generates Lagrange polynomial given a cell, index, and domain +// size. The polynomial will be equal to 1 when `cell == idx`, otherwise 0. +// The value of the cell needs to be in the range [0, domain_size) +fn generate_lagrange_base_polynomial( + cell: Cell, + idx: u64, + domain_size: u64, +) -> Expression { + let mut base_poly = 1.expr(); + let mut accumulated_inverse = 1.expr(); + for x in 0..domain_size { + if x != idx { + base_poly = base_poly * (cell.expr() - x.expr()); + let inverse = if x < idx { + F::from_u128((idx - x) as u128).invert().unwrap() + } else { + -F::from_u128((x - idx) as u128).invert().unwrap() + }; + accumulated_inverse = accumulated_inverse * inverse; + } + } + base_poly * accumulated_inverse +} + +#[derive(Clone, Debug)] +pub struct ShlWordsGadget { + a: util::Word, + shift: util::Word, + b: util::Word, + // slice_front means the higher part of split digit + // slice_back means the lower part of the split digit + a_slice_front: [Cell; 32], + a_slice_back: [Cell; 32], + // shift_div64, shift_mod64_div8, shift_mod8 + // is used to seperate shift[0] + shift_div64: Cell, + shift_mod64_div8: Cell, + shift_mod64_decpow: Cell, // means 2^(8-shift_mod64) + shift_mod64_pow: Cell, // means 2^shift_mod64 + shift_mod8: Cell, + // if combination of shift[1..32] == 0 + // shift_overflow will be equal to 0, otherwise 1. + shift_overflow: Cell, + // is_zero will check combination of shift[1..32] == 0 + is_zero: IsZeroGadget, +} +impl ShlWordsGadget { + pub(crate) fn construct( + cb: &mut ConstraintBuilder, + a: util::Word, + shift: util::Word, + ) -> Self { + let b = cb.query_word(); + let a_slice_front = array_init::array_init(|_| cb.query_byte()); + let a_slice_back = array_init::array_init(|_| cb.query_byte()); + let shift_div64 = cb.query_cell(); + let shift_mod64_div8 = cb.query_cell(); + let shift_mod64_decpow = cb.query_cell(); + let shift_mod64_pow = cb.query_cell(); + let shift_mod8 = cb.query_cell(); + let shift_overflow = cb.query_bool(); + + // check (combination of shift[1..32] == 0) == 1 - shift_overflow + let mut sum = 0.expr(); + (1..32).for_each(|idx| sum = sum.clone() + shift.cells[idx].expr()); + let is_zero = IsZeroGadget::construct(cb, sum); + cb.require_equal( + "shift_overflow == shift > 256 ", + shift_overflow.expr(), + 1.expr() - is_zero.expr(), + ); + + // rename variable: + // shift_div64 :a + // shift_mod64_div8:b + // shift_mod8:c + // we split shift[0] to the equation: + // shift[0] == a * 64 + b * 8 + c + let shift_mod64 = 8.expr() * shift_mod64_div8.expr() + shift_mod8.expr(); + cb.require_equal( + "shift[0] == shift_div64 * 64 + shift_mod64_div8 * 8 + shift_mod8", + shift.cells[0].expr(), + shift_div64.expr() * 64.expr() + shift_mod64.clone(), + ); + + // merge 8 8-bit cell for a 64-bit expression + // for a, a_slice_front, a_slice_back, b + let mut a_digits = vec![]; + let mut a_slice_front_digits = vec![]; + let mut a_slice_back_digits = vec![]; + let mut b_digits = vec![]; + for virtual_idx in 0..4 { + let now_idx = (virtual_idx * 8) as usize; + a_digits.push(from_bytes::expr(&a.cells[now_idx..now_idx + 8])); + a_slice_back_digits.push(from_bytes::expr(&a_slice_back[now_idx..now_idx + 8])); + a_slice_front_digits.push(from_bytes::expr(&a_slice_front[now_idx..now_idx + 8])); + b_digits.push(from_bytes::expr(&b.cells[now_idx..now_idx + 8])); + } + + // check combination of a_slice_back_digits and a_slice_front_digits + // == b_digits + let mut shl_constraints = (0..4).map(|_| 0.expr()).collect::>>(); + for transplacement in (0_usize)..(4_usize) { + // generate the polynomial depends on the shift_div64 + let select_transplacement_polynomial = + generate_lagrange_base_polynomial(shift_div64.clone(), transplacement as u64, 4u64); + for idx in 0..(4 - transplacement) { + let tmpidx = idx + transplacement; + let merge_a = if idx == (0_usize) { + a_slice_back_digits[idx].clone() * shift_mod64_pow.expr() + } else { + a_slice_back_digits[idx].clone() * shift_mod64_pow.expr() + + a_slice_front_digits[idx - 1].clone() + }; + shl_constraints[tmpidx] = shl_constraints[tmpidx].clone() + + select_transplacement_polynomial.clone() + * select::expr( + shift_overflow.expr(), + b_digits[tmpidx].clone(), + merge_a - b_digits[tmpidx].clone(), + ); + } + for idx in 0..transplacement { + shl_constraints[idx] = shl_constraints[idx].clone() + + select_transplacement_polynomial.clone() * b_digits[idx].clone(); + } + } + (0..4).for_each(|idx| { + cb.require_zero( + "merge a_slice_back_digits and a_slice_front_digits == b_digits", + shl_constraints[idx].clone(), + ) + }); + + // for i in 0..4 + // a_slice_back_digits[i] + a_slice_front_digits * shift_mod64_decpow + // == a_digits[i] + for idx in 0..4 { + cb.require_equal( + "a[idx] == a_slice_back[idx] + a_slice_front[idx] * shift_mod64_decpow", + a_slice_back_digits[idx].clone() + + a_slice_front_digits[idx].clone() * shift_mod64_decpow.expr(), + a_digits[idx].clone(), + ); + } + + // check serveral higher cells == 0 for slice_back and slice_front + let mut equal_to_zero = 0.expr(); + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.clone(), + digit_transplacement as u64, + 8u64, + ); + for virtual_idx in 0..4 { + for idx in (digit_transplacement + 1)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_front[nowidx].expr()); + } + for idx in (8 - digit_transplacement)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_back[nowidx].expr()); + } + } + } + + //check the specific 4 cells in 0..(1 << shift_mod8). + //check another specific 4 cells in 0..(1 << (8 - shift_mod8)). + for virtual_idx in 0..4 { + let mut slice_bits_polynomial = vec![0.expr(), 0.expr()]; + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.clone(), + digit_transplacement as u64, + 8u64, + ); + let nowidx = (virtual_idx * 8 + digit_transplacement) as usize; + slice_bits_polynomial[0] = slice_bits_polynomial[0].clone() + + select_transplacement_polynomial.clone() * a_slice_front[nowidx].expr(); + let nowidx = (virtual_idx * 8 + 7 - digit_transplacement) as usize; + slice_bits_polynomial[1] = slice_bits_polynomial[1].clone() + + select_transplacement_polynomial.clone() * a_slice_back[nowidx].expr(); + } + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + shift_mod8.expr(), + slice_bits_polynomial[0].clone(), + 0.expr(), + ], + }, + ); + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + 8.expr() - shift_mod8.expr(), + slice_bits_polynomial[1].clone(), + 0.expr(), + ], + }, + ); + } + + // check: + // 2^shift_mod64 == shift_mod64_pow + // 2^(8-shift_mod64) == shift_mod64_decpow + cb.add_lookup( + "pow_of_two lookup", + Lookup::Fixed { + tag: FixedTableTag::Pow64.expr(), + values: [ + shift_mod64, + shift_mod64_pow.expr(), + shift_mod64_decpow.expr(), + ], + }, + ); + + cb.add_lookup( + "shift_div64 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [2.expr(), shift_div64.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod64_div8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod64_div8.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod8.expr(), 0.expr()], + }, + ); + + Self { + a, + shift, + b, + a_slice_front, + a_slice_back, + shift_div64, + shift_mod64_div8, + shift_mod64_decpow, + shift_mod64_pow, + shift_mod8, + shift_overflow, + is_zero, + } + } + + pub(crate) fn assign( + &self, + region: &mut Region<'_, F>, + offset: usize, + a: Word, + shift: Word, + b: Word, + ) -> Result<(), Error> { + self.assign_witness(region, offset, &a, &shift)?; + self.a.assign(region, offset, Some(a.to_le_bytes()))?; + self.shift + .assign(region, offset, Some(shift.to_le_bytes()))?; + self.b.assign(region, offset, Some(b.to_le_bytes()))?; + Ok(()) + } + + pub(crate) fn b(&self) -> &util::Word { + &self.b + } + + fn assign_witness( + &self, + region: &mut Region<'_, F>, + offset: usize, + wa: &Word, + wshift: &Word, + ) -> Result<(), Error> { + let a8s = wa.to_le_bytes(); + let shift = wshift.to_le_bytes()[0] as u128; + let shift_div64 = shift / 64; + let shift_mod64_div8 = shift % 64 / 8; + let shift_mod64 = shift % 64; + let shift_mod64_pow = 1u128 << shift_mod64; + let shift_mod64_decpow = (1u128 << 64) / (shift_mod64_pow as u128); + let shift_mod8 = shift % 8; + let mut a_slice_front = [0u8; 32]; + let mut a_slice_back = [0u8; 32]; + for virtual_idx in 0..4 { + let mut tmp_a: u64 = 0; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + tmp_a += (1u64 << (8 * idx)) * (a8s[now_idx] as u64); + } + let mut slice_back = if shift_mod64 == 0 { + tmp_a + } else { + tmp_a % (1u64 << (64 - shift_mod64)) + }; + let mut slice_front = if shift_mod64 == 0 { + 0 + } else { + tmp_a / (1u64 << (64 - shift_mod64)) + }; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + a_slice_back[now_idx] = (slice_back % (1 << 8)) as u8; + a_slice_front[now_idx] = (slice_front % (1 << 8)) as u8; + slice_back >>= 8; + slice_front >>= 8; + } + } + a_slice_front + .iter() + .zip(self.a_slice_front.iter()) + .try_for_each(|(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + })?; + a_slice_back + .iter() + .zip(self.a_slice_back.iter()) + .try_for_each(|(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + })?; + self.shift_div64 + .assign(region, offset, Some(F::from_u128(shift_div64)))?; + self.shift_mod64_div8 + .assign(region, offset, Some(F::from_u128(shift_mod64_div8)))?; + self.shift_mod64_decpow + .assign(region, offset, Some(F::from_u128(shift_mod64_decpow)))?; + self.shift_mod64_pow + .assign(region, offset, Some(F::from_u128(shift_mod64_pow)))?; + self.shift_mod8 + .assign(region, offset, Some(F::from_u128(shift_mod8)))?; + + let mut sum: u128 = 0; + wshift.to_le_bytes().iter().for_each(|v| sum += *v as u128); + sum -= shift as u128; + let shift_overflow = sum != 0; + self.is_zero.assign(region, offset, F::from_u128(sum))?; + self.shift_overflow + .assign(region, offset, Some(F::from_u128(shift_overflow as u128)))?; + Ok(()) + } +} diff --git a/zkevm-circuits/src/evm_circuit/witness.rs b/zkevm-circuits/src/evm_circuit/witness.rs index 1b4af57a40..a22df44943 100644 --- a/zkevm-circuits/src/evm_circuit/witness.rs +++ b/zkevm-circuits/src/evm_circuit/witness.rs @@ -577,6 +577,7 @@ impl From<&bus_mapping::circuit_input_builder::ExecStep> for ExecutionState { OpcodeId::SUB => ExecutionState::ADD, OpcodeId::EQ | OpcodeId::LT | OpcodeId::GT => ExecutionState::CMP, OpcodeId::SLT | OpcodeId::SGT => ExecutionState::SCMP, + OpcodeId::SHL => ExecutionState::SHL, OpcodeId::SIGNEXTEND => ExecutionState::SIGNEXTEND, OpcodeId::STOP => ExecutionState::STOP, OpcodeId::AND => ExecutionState::BITWISE, diff --git a/zkevm-circuits/src/test_util.rs b/zkevm-circuits/src/test_util.rs index c1b9b8332e..be330101dd 100644 --- a/zkevm-circuits/src/test_util.rs +++ b/zkevm-circuits/src/test_util.rs @@ -22,6 +22,8 @@ pub fn get_fixed_table(conf: FixedTableConfig) -> Vec { FixedTableTag::Range512, FixedTableTag::SignByte, FixedTableTag::ResponsibleOpcode, + FixedTableTag::Bitslevel, + FixedTableTag::Pow64, ] } FixedTableConfig::Complete => FixedTableTag::iterator().collect(),