diff --git a/zkevm-circuits/src/evm_circuit.rs b/zkevm-circuits/src/evm_circuit.rs index 62e47ec712..e149638a00 100644 --- a/zkevm-circuits/src/evm_circuit.rs +++ b/zkevm-circuits/src/evm_circuit.rs @@ -493,6 +493,8 @@ pub mod test { FixedTableTag::Range1024, FixedTableTag::SignByte, FixedTableTag::ResponsibleOpcode, + FixedTableTag::Bitslevel, + FixedTableTag::Pow64, ], ) } diff --git a/zkevm-circuits/src/evm_circuit/execution.rs b/zkevm-circuits/src/evm_circuit/execution.rs index c63266a019..aa72e9cf70 100644 --- a/zkevm-circuits/src/evm_circuit/execution.rs +++ b/zkevm-circuits/src/evm_circuit/execution.rs @@ -58,6 +58,8 @@ mod pc; mod pop; mod push; mod selfbalance; +mod shl; +mod shr; mod signed_comparator; mod signextend; mod sload; @@ -102,6 +104,8 @@ use pc::PcGadget; use pop::PopGadget; use push::PushGadget; use selfbalance::SelfbalanceGadget; +use shl::ShlGadget; +use shr::ShrGadget; use signed_comparator::SignedComparatorGadget; use signextend::SignextendGadget; use sload::SloadGadget; @@ -176,6 +180,8 @@ pub(crate) struct ExecutionConfig { pop_gadget: PopGadget, push_gadget: PushGadget, selfbalance_gadget: SelfbalanceGadget, + shl_gadget: ShlGadget, + shr_gadget: ShrGadget, signed_comparator_gadget: SignedComparatorGadget, signextend_gadget: SignextendGadget, sload_gadget: SloadGadget, @@ -365,6 +371,8 @@ impl ExecutionConfig { pop_gadget: configure_gadget!(), push_gadget: configure_gadget!(), selfbalance_gadget: configure_gadget!(), + shl_gadget: configure_gadget!(), + shr_gadget: configure_gadget!(), signed_comparator_gadget: configure_gadget!(), signextend_gadget: configure_gadget!(), sload_gadget: configure_gadget!(), @@ -805,6 +813,8 @@ impl ExecutionConfig { ExecutionState::SELFBALANCE => assign_exec_step!(self.selfbalance_gadget), ExecutionState::SIGNEXTEND => assign_exec_step!(self.signextend_gadget), ExecutionState::SLOAD => assign_exec_step!(self.sload_gadget), + ExecutionState::SHL => assign_exec_step!(self.shl_gadget), + ExecutionState::SHR => assign_exec_step!(self.shr_gadget), ExecutionState::SSTORE => assign_exec_step!(self.sstore_gadget), ExecutionState::STOP => assign_exec_step!(self.stop_gadget), ExecutionState::SWAP => assign_exec_step!(self.swap_gadget), diff --git a/zkevm-circuits/src/evm_circuit/execution/shl.rs b/zkevm-circuits/src/evm_circuit/execution/shl.rs new file mode 100644 index 0000000000..836bc8e08b --- /dev/null +++ b/zkevm-circuits/src/evm_circuit/execution/shl.rs @@ -0,0 +1,125 @@ +use crate::{ + evm_circuit::{ + execution::ExecutionGadget, + step::ExecutionState, + util::{ + common_gadget::SameContextGadget, + constraint_builder::{ConstraintBuilder, StepStateTransition, Transition::Delta}, + math_gadget::ShlWordsGadget, + }, + witness::{Block, Call, ExecStep, Transaction}, + }, + util::Expr, +}; + +use bus_mapping::evm::OpcodeId; +use eth_types::Field; +use halo2_proofs::{circuit::Region, plonk::Error}; + +#[derive(Clone, Debug)] +pub(crate) struct ShlGadget { + same_context: SameContextGadget, + shl_words: ShlWordsGadget, +} + +impl ExecutionGadget for ShlGadget { + const NAME: &'static str = "SHL"; + + const EXECUTION_STATE: ExecutionState = ExecutionState::SHL; + + fn configure(cb: &mut ConstraintBuilder) -> Self { + let opcode = cb.query_cell(); + + let a = cb.query_word(); + let shift = cb.query_word(); + + cb.stack_pop(shift.expr()); + cb.stack_pop(a.expr()); + let shl_words = ShlWordsGadget::construct(cb, a, shift); + cb.stack_push(shl_words.b().expr()); + + let step_state_transition = StepStateTransition { + rw_counter: Delta(3.expr()), + program_counter: Delta(1.expr()), + stack_pointer: Delta(1.expr()), + gas_left: Delta(-OpcodeId::SHL.constant_gas_cost().expr()), + ..Default::default() + }; + let same_context = SameContextGadget::construct(cb, opcode, step_state_transition); + + Self { + same_context, + shl_words, + } + } + + fn assign_exec_step( + &self, + region: &mut Region<'_, F>, + offset: usize, + block: &Block, + _: &Transaction, + _: &Call, + step: &ExecStep, + ) -> Result<(), Error> { + self.same_context.assign_exec_step(region, offset, step)?; + let indices = [step.rw_indices[0], step.rw_indices[1], step.rw_indices[2]]; + let [shift, a, b] = indices.map(|idx| block.rws[idx].stack_value()); + self.shl_words.assign(region, offset, a, shift, b) + } +} + +#[cfg(test)] +mod test { + use crate::evm_circuit::test::rand_word; + use crate::test_util::run_test_circuits; + use eth_types::evm_types::OpcodeId; + use eth_types::{bytecode, Word}; + use mock::TestContext; + use rand::Rng; + + fn test_ok(opcode: OpcodeId, a: Word, shift: Word) { + let bytecode = bytecode! { + PUSH32(a) + PUSH32(shift) + #[start] + .write_op(opcode) + STOP + }; + assert_eq!( + run_test_circuits( + TestContext::<2, 1>::simple_ctx_with_bytecode(bytecode).unwrap(), + None + ), + Ok(()) + ); + } + + #[test] + fn shl_gadget_simple() { + test_ok(OpcodeId::SHL, 0x02FF.into(), 0x1.into()); + } + + #[test] + fn shl_gadget_rand_normal_shift() { + let a = rand_word(); + let mut rng = rand::thread_rng(); + let shift = rng.gen_range(0..=255); + test_ok(OpcodeId::SHL, a, shift.into()); + } + + #[test] + fn shl_gadget_rand_overflow_shift() { + let a = rand_word(); + let shift = Word::from_big_endian(&[255u8; 32]); + test_ok(OpcodeId::SHL, a, shift); + } + + //this testcase manage to check the split is correct. + #[test] + fn shl_gadget_constant_shift() { + let a = rand_word(); + test_ok(OpcodeId::SHL, a, 8.into()); + test_ok(OpcodeId::SHL, a, 64.into()); + } +} diff --git a/zkevm-circuits/src/evm_circuit/execution/shr.rs b/zkevm-circuits/src/evm_circuit/execution/shr.rs new file mode 100644 index 0000000000..42522d07b6 --- /dev/null +++ b/zkevm-circuits/src/evm_circuit/execution/shr.rs @@ -0,0 +1,124 @@ +use crate::{ + evm_circuit::{ + execution::ExecutionGadget, + step::ExecutionState, + util::{ + common_gadget::SameContextGadget, + constraint_builder::{ConstraintBuilder, StepStateTransition, Transition::Delta}, + math_gadget::ShrWordsGadget, + }, + witness::{Block, Call, ExecStep, Transaction}, + }, + util::Expr, +}; +use bus_mapping::evm::OpcodeId; +use eth_types::Field; +use halo2_proofs::{circuit::Region, plonk::Error}; + +#[derive(Clone, Debug)] +pub(crate) struct ShrGadget { + same_context: SameContextGadget, + shr_words: ShrWordsGadget, +} + +impl ExecutionGadget for ShrGadget { + const NAME: &'static str = "SHR"; + + const EXECUTION_STATE: ExecutionState = ExecutionState::SHR; + + fn configure(cb: &mut ConstraintBuilder) -> Self { + let opcode = cb.query_cell(); + + let a = cb.query_word(); + let shift = cb.query_word(); + + cb.stack_pop(shift.expr()); + cb.stack_pop(a.expr()); + let shr_words = ShrWordsGadget::construct(cb, a, shift); + cb.stack_push(shr_words.b().expr()); + + let step_state_transition = StepStateTransition { + rw_counter: Delta(3.expr()), + program_counter: Delta(1.expr()), + stack_pointer: Delta(1.expr()), + gas_left: Delta(-OpcodeId::SHR.constant_gas_cost().expr()), + ..Default::default() + }; + let same_context = SameContextGadget::construct(cb, opcode, step_state_transition); + + Self { + same_context, + shr_words, + } + } + + fn assign_exec_step( + &self, + region: &mut Region<'_, F>, + offset: usize, + block: &Block, + _: &Transaction, + _: &Call, + step: &ExecStep, + ) -> Result<(), Error> { + self.same_context.assign_exec_step(region, offset, step)?; + let indices = [step.rw_indices[0], step.rw_indices[1], step.rw_indices[2]]; + let [shift, a, b] = indices.map(|idx| block.rws[idx].stack_value()); + self.shr_words.assign(region, offset, a, shift, b) + } +} + +#[cfg(test)] +mod test { + use crate::evm_circuit::test::rand_word; + use crate::test_util::run_test_circuits; + use eth_types::evm_types::OpcodeId; + use eth_types::{bytecode, Word}; + use mock::TestContext; + use rand::Rng; + + fn test_ok(opcode: OpcodeId, a: Word, shift: Word) { + let bytecode = bytecode! { + PUSH32(a) + PUSH32(shift) + #[start] + .write_op(opcode) + STOP + }; + assert_eq!( + run_test_circuits( + TestContext::<2, 1>::simple_ctx_with_bytecode(bytecode).unwrap(), + None + ), + Ok(()) + ); + } + + #[test] + fn shr_gadget_simple() { + test_ok(OpcodeId::SHR, 0x02FF.into(), 0x1.into()); + } + + #[test] + fn shr_gadget_rand_normal_shift() { + let a = rand_word(); + let mut rng = rand::thread_rng(); + let shift = rng.gen_range(0..=255); + test_ok(OpcodeId::SHR, a, shift.into()); + } + + #[test] + fn shr_gadget_rand_overflow_shift() { + let a = rand_word(); + let shift = Word::from_big_endian(&[255u8; 32]); + test_ok(OpcodeId::SHR, a, shift); + } + + //this testcase manage to check the split is correct. + #[test] + fn shr_gadget_constant_shift() { + let a = rand_word(); + test_ok(OpcodeId::SHR, a, 8.into()); + test_ok(OpcodeId::SHR, a, 64.into()); + } +} diff --git a/zkevm-circuits/src/evm_circuit/table.rs b/zkevm-circuits/src/evm_circuit/table.rs index 11ad21eb45..8f5ec2c9b9 100644 --- a/zkevm-circuits/src/evm_circuit/table.rs +++ b/zkevm-circuits/src/evm_circuit/table.rs @@ -40,6 +40,8 @@ pub enum FixedTableTag { BitwiseOr, BitwiseXor, ResponsibleOpcode, + Bitslevel, + Pow64, } impl FixedTableTag { @@ -58,6 +60,8 @@ impl FixedTableTag { Self::BitwiseOr, Self::BitwiseXor, Self::ResponsibleOpcode, + Self::Bitslevel, + Self::Pow64, ] .iter() .copied() @@ -120,6 +124,17 @@ impl FixedTableTag { }) })) } + Self::Bitslevel => Box::new((0..9).flat_map(move |level| { + (0..(1 << level)).map(move |idx| [tag, F::from(level), F::from(idx), F::zero()]) + })), + Self::Pow64 => Box::new((0..64).map(move |idx| { + [ + tag, + F::from(idx), + F::from_u128(1u128 << idx), + F::from_u128(1u128 << (64 - idx)), + ] + })), } } } diff --git a/zkevm-circuits/src/evm_circuit/util/math_gadget.rs b/zkevm-circuits/src/evm_circuit/util/math_gadget.rs index 8c1136c593..c1899cd122 100644 --- a/zkevm-circuits/src/evm_circuit/util/math_gadget.rs +++ b/zkevm-circuits/src/evm_circuit/util/math_gadget.rs @@ -1,8 +1,11 @@ use super::CachedRegion; use crate::{ - evm_circuit::util::{ - self, constraint_builder::ConstraintBuilder, from_bytes, pow_of_two, pow_of_two_expr, - select, split_u256, split_u256_limb64, sum, Cell, + evm_circuit::{ + table::{FixedTableTag, Lookup}, + util::{ + self, constraint_builder::ConstraintBuilder, from_bytes, pow_of_two, pow_of_two_expr, + select, split_u256, split_u256_limb64, sum, Cell, + }, }, util::Expr, }; @@ -729,11 +732,11 @@ impl MinMaxGadget { pub(crate) fn generate_lagrange_base_polynomial< F: Field, Exp: Expr, - R: Iterator, + I: Iterator, >( exp: Exp, val: usize, - range: R, + range: I, ) -> Expression { let mut numerator = 1u64.expr(); let mut denominator = F::from(1); @@ -896,3 +899,667 @@ impl MulAddWordsGadget { self.overflow.clone() } } + +#[derive(Clone, Debug)] +pub struct ShrWordsGadget { + a: util::Word, + shift: util::Word, + b: util::Word, + // slice_hi means the higher part of split digit + // slice_lo means the lower part of the split digit + a_slice_hi: [Cell; 32], + a_slice_lo: [Cell; 32], + // shift_div64, shift_mod64_div8, shift_mod8 + // is used to seperate shift[0] + shift_div64: Cell, + shift_mod64_div8: Cell, + shift_mod64_decpow: Cell, // means 2^(8-shift_mod64) + shift_mod64_pow: Cell, // means 2^shift_mod64 + shift_mod8: Cell, + // is_zero will check combination of shift[1..32] == 0 + is_zero: IsZeroGadget, +} + +impl ShrWordsGadget { + pub(crate) fn construct( + cb: &mut ConstraintBuilder, + a: util::Word, + shift: util::Word, + ) -> Self { + let b = cb.query_word(); + let a_slice_hi = cb.query_bytes(); + let a_slice_lo = cb.query_bytes(); + let shift_div64 = cb.query_cell(); + let shift_mod64_div8 = cb.query_cell(); + let shift_mod64_decpow = cb.query_cell(); + let shift_mod64_pow = cb.query_cell(); + let shift_mod8 = cb.query_cell(); + + // check (combination of shift[1..32] == 0) == 1 - shift_overflow + let mut sum = 0.expr(); + (1..32).for_each(|idx| sum = sum.clone() + shift.cells[idx].expr()); + let is_zero = IsZeroGadget::construct(cb, sum); + // if combination of shift[1..32] == 0 + // shift_overflow will be equal to 0, otherwise 1. + let shift_overflow = 1.expr() - is_zero.expr(); + cb.require_equal( + "shift_overflow == shift > 256 ", + shift_overflow.clone(), + 1.expr() - is_zero.expr(), + ); + + // rename variable: + // shift_div64 :a + // shift_mod64_div8:b + // shift_mod8:c + // we split shift[0] to the equation: + // shift[0] == a * 64 + b * 8 + c + let shift_mod64 = 8.expr() * shift_mod64_div8.expr() + shift_mod8.expr(); + cb.require_equal( + "shift[0] == shift_div64 * 64 + shift_mod64_div8 * 8 + shift_mod8", + shift.cells[0].expr(), + shift_div64.expr() * 64.expr() + shift_mod64.clone(), + ); + + // merge 8 8-bit cell for a 64-bit expression + // for a, a_slice_hi, a_slice_lo, b + let mut a_digits = vec![]; + let mut a_slice_hi_digits = vec![]; + let mut a_slice_lo_digits = vec![]; + let mut b_digits = vec![]; + for virtual_idx in 0..4 { + let now_idx = (virtual_idx * 8) as usize; + a_digits.push(from_bytes::expr(&a.cells[now_idx..now_idx + 8])); + a_slice_lo_digits.push(from_bytes::expr(&a_slice_lo[now_idx..now_idx + 8])); + a_slice_hi_digits.push(from_bytes::expr(&a_slice_hi[now_idx..now_idx + 8])); + b_digits.push(from_bytes::expr(&b.cells[now_idx..now_idx + 8])); + } + + // check combination of a_slice_back_digits and a_slice_front_digits + // == b_digits + let mut shr_constraints = (0..4).map(|_| 0.expr()).collect::>>(); + for transplacement in (0_usize)..(4_usize) { + // generate the polynomial depends on the shift_div64 + let select_transplacement_polynomial = + generate_lagrange_base_polynomial(shift_div64.expr(), transplacement, 0..4); + for idx in 0..(4 - transplacement) { + let tmpidx = idx + transplacement; + let merge_a = if idx + transplacement == (3_usize) { + a_slice_hi_digits[tmpidx].clone() + } else { + a_slice_hi_digits[tmpidx].clone() + + a_slice_lo_digits[tmpidx + 1].clone() * shift_mod64_decpow.expr() + }; + shr_constraints[idx] = shr_constraints[idx].clone() + + select_transplacement_polynomial.clone() + * select::expr( + shift_overflow.clone(), + b_digits[idx].clone(), + merge_a - b_digits[idx].clone(), + ); + } + for idx in (4 - transplacement)..4 { + shr_constraints[idx] = shr_constraints[idx].clone() + + select_transplacement_polynomial.clone() * b_digits[idx].clone(); + } + } + (0..4).for_each(|idx| { + cb.require_zero( + "merge a_slice_lo_digits and a_slice_hi_digits == b_digits", + shr_constraints[idx].clone(), + ) + }); + + // for i in 0..4 + // a_slice_lo_digits[i] + a_slice_hi_digits * shift_mod64_pow + // == a_digits[i] + for idx in 0..4 { + cb.require_equal( + "a[idx] == a_slice_lo[idx] + a_slice_hi[idx] * shift_mod64_pow", + a_slice_lo_digits[idx].clone() + + a_slice_hi_digits[idx].clone() * shift_mod64_pow.expr(), + a_digits[idx].clone(), + ); + } + + // check serveral higher cells == 0 for slice_back and slice_front + let mut equal_to_zero = 0.expr(); + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.expr(), + digit_transplacement, + 0..8, + ); + for virtual_idx in 0..4 { + for idx in (digit_transplacement + 1)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_lo[nowidx].expr()); + } + for idx in (8 - digit_transplacement)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_hi[nowidx].expr()); + } + } + } + + //check the specific 4 cells in 0..(1 << shift_mod8). + //check another specific 4 cells in 0..(1 << (8 - shift_mod8)). + for virtual_idx in 0..4 { + let mut slice_bits_polynomial = vec![0.expr(), 0.expr()]; + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.expr(), + digit_transplacement, + 0..8, + ); + let nowidx = (virtual_idx * 8 + digit_transplacement) as usize; + slice_bits_polynomial[0] = slice_bits_polynomial[0].clone() + + select_transplacement_polynomial.clone() * a_slice_lo[nowidx].expr(); + let nowidx = (virtual_idx * 8 + 7 - digit_transplacement) as usize; + slice_bits_polynomial[1] = slice_bits_polynomial[1].clone() + + select_transplacement_polynomial.clone() * a_slice_hi[nowidx].expr(); + } + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + shift_mod8.expr(), + slice_bits_polynomial[0].clone(), + 0.expr(), + ], + }, + ); + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + 8.expr() - shift_mod8.expr(), + slice_bits_polynomial[1].clone(), + 0.expr(), + ], + }, + ); + } + + // check: + // 2^shift_mod64 == shift_mod64_pow + // 2^(8-shift_mod64) == shift_mod64_decpow + cb.add_lookup( + "pow_of_two lookup", + Lookup::Fixed { + tag: FixedTableTag::Pow64.expr(), + values: [ + shift_mod64, + shift_mod64_pow.expr(), + shift_mod64_decpow.expr(), + ], + }, + ); + + cb.add_lookup( + "shift_div64 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [2.expr(), shift_div64.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod64_div8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod64_div8.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod8.expr(), 0.expr()], + }, + ); + + Self { + a, + shift, + b, + a_slice_hi, + a_slice_lo, + shift_div64, + shift_mod64_div8, + shift_mod64_decpow, + shift_mod64_pow, + shift_mod8, + is_zero, + } + } + + pub(crate) fn assign( + &self, + region: &mut Region<'_, F>, + offset: usize, + a: Word, + shift: Word, + b: Word, + ) -> Result<(), Error> { + self.assign_witness(region, offset, &a, &shift)?; + self.a.assign(region, offset, Some(a.to_le_bytes()))?; + self.shift + .assign(region, offset, Some(shift.to_le_bytes()))?; + self.b.assign(region, offset, Some(b.to_le_bytes()))?; + Ok(()) + } + + pub(crate) fn b(&self) -> &util::Word { + &self.b + } + + fn assign_witness( + &self, + region: &mut Region<'_, F>, + offset: usize, + wa: &Word, + wshift: &Word, + ) -> Result<(), Error> { + let a8s = wa.to_le_bytes(); + let shift = wshift.to_le_bytes()[0] as u128; + let shift_div64 = shift / 64; + let shift_mod64_div8 = shift % 64 / 8; + let shift_mod64 = shift % 64; + let shift_mod64_pow = 1u128 << shift_mod64; + let shift_mod64_decpow = (1u128 << 64) / (shift_mod64_pow as u128); + let shift_mod8 = shift % 8; + let mut a_slice_hi = [0u8; 32]; + let mut a_slice_lo = [0u8; 32]; + for virtual_idx in 0..4 { + let mut tmp_a: u64 = 0; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + tmp_a += (1u64 << (8 * idx)) * (a8s[now_idx] as u64); + } + let mut slice_back = if shift_mod64 == 0 { + 0 + } else { + tmp_a % (1u64 << shift_mod64) + }; + let mut slice_front = if shift_mod64 == 0 { + tmp_a + } else { + tmp_a / (1u64 << shift_mod64) + }; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + a_slice_lo[now_idx] = (slice_back % (1 << 8)) as u8; + a_slice_hi[now_idx] = (slice_front % (1 << 8)) as u8; + slice_back >>= 8; + slice_front >>= 8; + } + } + a_slice_hi.iter().zip(self.a_slice_hi.iter()).try_for_each( + |(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + }, + )?; + a_slice_lo.iter().zip(self.a_slice_lo.iter()).try_for_each( + |(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + }, + )?; + self.shift_div64 + .assign(region, offset, Some(F::from_u128(shift_div64)))?; + self.shift_mod64_div8 + .assign(region, offset, Some(F::from_u128(shift_mod64_div8)))?; + self.shift_mod64_decpow + .assign(region, offset, Some(F::from_u128(shift_mod64_decpow)))?; + self.shift_mod64_pow + .assign(region, offset, Some(F::from_u128(shift_mod64_pow)))?; + self.shift_mod8 + .assign(region, offset, Some(F::from_u128(shift_mod8)))?; + + let mut sum: u128 = 0; + wshift.to_le_bytes().iter().for_each(|v| sum += *v as u128); + sum -= shift as u128; + self.is_zero.assign(region, offset, F::from_u128(sum))?; + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct ShlWordsGadget { + a: util::Word, + shift: util::Word, + b: util::Word, + // slice_front means the higher part of split digit + // slice_back means the lower part of the split digit + a_slice_front: [Cell; 32], + a_slice_back: [Cell; 32], + // shift_div64, shift_mod64_div8, shift_mod8 + // is used to seperate shift[0] + shift_div64: Cell, + shift_mod64_div8: Cell, + shift_mod64_decpow: Cell, // means 2^(8-shift_mod64) + shift_mod64_pow: Cell, // means 2^shift_mod64 + shift_mod8: Cell, + // if combination of shift[1..32] == 0 + // shift_overflow will be equal to 0, otherwise 1. + shift_overflow: Cell, + // is_zero will check combination of shift[1..32] == 0 + is_zero: IsZeroGadget, +} +impl ShlWordsGadget { + pub(crate) fn construct( + cb: &mut ConstraintBuilder, + a: util::Word, + shift: util::Word, + ) -> Self { + let b = cb.query_word(); + let a_slice_front = array_init::array_init(|_| cb.query_byte()); + let a_slice_back = array_init::array_init(|_| cb.query_byte()); + let shift_div64 = cb.query_cell(); + let shift_mod64_div8 = cb.query_cell(); + let shift_mod64_decpow = cb.query_cell(); + let shift_mod64_pow = cb.query_cell(); + let shift_mod8 = cb.query_cell(); + let shift_overflow = cb.query_bool(); + + // check (combination of shift[1..32] == 0) == 1 - shift_overflow + let mut sum = 0.expr(); + (1..32).for_each(|idx| sum = sum.clone() + shift.cells[idx].expr()); + let is_zero = IsZeroGadget::construct(cb, sum); + cb.require_equal( + "shift_overflow == shift > 256 ", + shift_overflow.expr(), + 1.expr() - is_zero.expr(), + ); + + // rename variable: + // shift_div64 :a + // shift_mod64_div8:b + // shift_mod8:c + // we split shift[0] to the equation: + // shift[0] == a * 64 + b * 8 + c + let shift_mod64 = 8.expr() * shift_mod64_div8.expr() + shift_mod8.expr(); + cb.require_equal( + "shift[0] == shift_div64 * 64 + shift_mod64_div8 * 8 + shift_mod8", + shift.cells[0].expr(), + shift_div64.expr() * 64.expr() + shift_mod64.clone(), + ); + + // merge 8 8-bit cell for a 64-bit expression + // for a, a_slice_front, a_slice_back, b + let mut a_digits = vec![]; + let mut a_slice_front_digits = vec![]; + let mut a_slice_back_digits = vec![]; + let mut b_digits = vec![]; + for virtual_idx in 0..4 { + let now_idx = (virtual_idx * 8) as usize; + a_digits.push(from_bytes::expr(&a.cells[now_idx..now_idx + 8])); + a_slice_back_digits.push(from_bytes::expr(&a_slice_back[now_idx..now_idx + 8])); + a_slice_front_digits.push(from_bytes::expr(&a_slice_front[now_idx..now_idx + 8])); + b_digits.push(from_bytes::expr(&b.cells[now_idx..now_idx + 8])); + } + + // check combination of a_slice_back_digits and a_slice_front_digits + // == b_digits + let mut shl_constraints = (0..4).map(|_| 0.expr()).collect::>>(); + for transplacement in (0_usize)..(4_usize) { + // generate the polynomial depends on the shift_div64 + let select_transplacement_polynomial = + generate_lagrange_base_polynomial(shift_div64.expr(), transplacement, 0..4); + for idx in 0..(4 - transplacement) { + let tmpidx = idx + transplacement; + let merge_a = if idx == (0_usize) { + a_slice_back_digits[idx].clone() * shift_mod64_pow.expr() + } else { + a_slice_back_digits[idx].clone() * shift_mod64_pow.expr() + + a_slice_front_digits[idx - 1].clone() + }; + shl_constraints[tmpidx] = shl_constraints[tmpidx].clone() + + select_transplacement_polynomial.clone() + * select::expr( + shift_overflow.expr(), + b_digits[tmpidx].clone(), + merge_a - b_digits[tmpidx].clone(), + ); + } + for idx in 0..transplacement { + shl_constraints[idx] = shl_constraints[idx].clone() + + select_transplacement_polynomial.clone() * b_digits[idx].clone(); + } + } + (0..4).for_each(|idx| { + cb.require_zero( + "merge a_slice_back_digits and a_slice_front_digits == b_digits", + shl_constraints[idx].clone(), + ) + }); + + // for i in 0..4 + // a_slice_back_digits[i] + a_slice_front_digits * shift_mod64_decpow + // == a_digits[i] + for idx in 0..4 { + cb.require_equal( + "a[idx] == a_slice_back[idx] + a_slice_front[idx] * shift_mod64_decpow", + a_slice_back_digits[idx].clone() + + a_slice_front_digits[idx].clone() * shift_mod64_decpow.expr(), + a_digits[idx].clone(), + ); + } + + // check serveral higher cells == 0 for slice_back and slice_front + let mut equal_to_zero = 0.expr(); + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.clone(), + digit_transplacement, + 0..8, + ); + for virtual_idx in 0..4 { + for idx in (digit_transplacement + 1)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_front[nowidx].expr()); + } + for idx in (8 - digit_transplacement)..8 { + let nowidx = (virtual_idx * 8 + idx) as usize; + equal_to_zero = equal_to_zero + + (select_transplacement_polynomial.clone() * a_slice_back[nowidx].expr()); + } + } + } + + //check the specific 4 cells in 0..(1 << shift_mod8). + //check another specific 4 cells in 0..(1 << (8 - shift_mod8)). + for virtual_idx in 0..4 { + let mut slice_bits_polynomial = vec![0.expr(), 0.expr()]; + for digit_transplacement in 0..8 { + let select_transplacement_polynomial = generate_lagrange_base_polynomial( + shift_mod64_div8.clone(), + digit_transplacement, + 0..8, + ); + let nowidx = (virtual_idx * 8 + digit_transplacement) as usize; + slice_bits_polynomial[0] = slice_bits_polynomial[0].clone() + + select_transplacement_polynomial.clone() * a_slice_front[nowidx].expr(); + let nowidx = (virtual_idx * 8 + 7 - digit_transplacement) as usize; + slice_bits_polynomial[1] = slice_bits_polynomial[1].clone() + + select_transplacement_polynomial.clone() * a_slice_back[nowidx].expr(); + } + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + shift_mod8.expr(), + slice_bits_polynomial[0].clone(), + 0.expr(), + ], + }, + ); + cb.add_lookup( + "slice_bits range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [ + 8.expr() - shift_mod8.expr(), + slice_bits_polynomial[1].clone(), + 0.expr(), + ], + }, + ); + } + + // check: + // 2^shift_mod64 == shift_mod64_pow + // 2^(8-shift_mod64) == shift_mod64_decpow + cb.add_lookup( + "pow_of_two lookup", + Lookup::Fixed { + tag: FixedTableTag::Pow64.expr(), + values: [ + shift_mod64, + shift_mod64_pow.expr(), + shift_mod64_decpow.expr(), + ], + }, + ); + + cb.add_lookup( + "shift_div64 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [2.expr(), shift_div64.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod64_div8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod64_div8.expr(), 0.expr()], + }, + ); + cb.add_lookup( + "shift_mod8 range lookup", + Lookup::Fixed { + tag: FixedTableTag::Bitslevel.expr(), + values: [3.expr(), shift_mod8.expr(), 0.expr()], + }, + ); + + Self { + a, + shift, + b, + a_slice_front, + a_slice_back, + shift_div64, + shift_mod64_div8, + shift_mod64_decpow, + shift_mod64_pow, + shift_mod8, + shift_overflow, + is_zero, + } + } + + pub(crate) fn assign( + &self, + region: &mut Region<'_, F>, + offset: usize, + a: Word, + shift: Word, + b: Word, + ) -> Result<(), Error> { + self.assign_witness(region, offset, &a, &shift)?; + self.a.assign(region, offset, Some(a.to_le_bytes()))?; + self.shift + .assign(region, offset, Some(shift.to_le_bytes()))?; + self.b.assign(region, offset, Some(b.to_le_bytes()))?; + Ok(()) + } + + pub(crate) fn b(&self) -> &util::Word { + &self.b + } + + fn assign_witness( + &self, + region: &mut Region<'_, F>, + offset: usize, + wa: &Word, + wshift: &Word, + ) -> Result<(), Error> { + let a8s = wa.to_le_bytes(); + let shift = wshift.to_le_bytes()[0] as u128; + let shift_div64 = shift / 64; + let shift_mod64_div8 = shift % 64 / 8; + let shift_mod64 = shift % 64; + let shift_mod64_pow = 1u128 << shift_mod64; + let shift_mod64_decpow = (1u128 << 64) / (shift_mod64_pow as u128); + let shift_mod8 = shift % 8; + let mut a_slice_front = [0u8; 32]; + let mut a_slice_back = [0u8; 32]; + for virtual_idx in 0..4 { + let mut tmp_a: u64 = 0; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + tmp_a += (1u64 << (8 * idx)) * (a8s[now_idx] as u64); + } + let mut slice_back = if shift_mod64 == 0 { + tmp_a + } else { + tmp_a % (1u64 << (64 - shift_mod64)) + }; + let mut slice_front = if shift_mod64 == 0 { + 0 + } else { + tmp_a / (1u64 << (64 - shift_mod64)) + }; + for idx in 0..8 { + let now_idx = virtual_idx * 8 + idx; + a_slice_back[now_idx] = (slice_back % (1 << 8)) as u8; + a_slice_front[now_idx] = (slice_front % (1 << 8)) as u8; + slice_back >>= 8; + slice_front >>= 8; + } + } + a_slice_front + .iter() + .zip(self.a_slice_front.iter()) + .try_for_each(|(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + })?; + a_slice_back + .iter() + .zip(self.a_slice_back.iter()) + .try_for_each(|(bt, assignee)| -> Result<(), Error> { + assignee.assign(region, offset, Some(F::from(*bt as u64)))?; + Ok(()) + })?; + self.shift_div64 + .assign(region, offset, Some(F::from_u128(shift_div64)))?; + self.shift_mod64_div8 + .assign(region, offset, Some(F::from_u128(shift_mod64_div8)))?; + self.shift_mod64_decpow + .assign(region, offset, Some(F::from_u128(shift_mod64_decpow)))?; + self.shift_mod64_pow + .assign(region, offset, Some(F::from_u128(shift_mod64_pow)))?; + self.shift_mod8 + .assign(region, offset, Some(F::from_u128(shift_mod8)))?; + + let mut sum: u128 = 0; + wshift.to_le_bytes().iter().for_each(|v| sum += *v as u128); + sum -= shift as u128; + let shift_overflow = sum != 0; + self.is_zero.assign(region, offset, F::from_u128(sum))?; + self.shift_overflow + .assign(region, offset, Some(F::from_u128(shift_overflow as u128)))?; + Ok(()) + } +} diff --git a/zkevm-circuits/src/evm_circuit/witness.rs b/zkevm-circuits/src/evm_circuit/witness.rs index b73f2d665f..9bbaae615e 100644 --- a/zkevm-circuits/src/evm_circuit/witness.rs +++ b/zkevm-circuits/src/evm_circuit/witness.rs @@ -1196,7 +1196,9 @@ impl From<&circuit_input_builder::ExecStep> for ExecutionState { OpcodeId::ADD | OpcodeId::SUB => ExecutionState::ADD_SUB, OpcodeId::MUL | OpcodeId::DIV | OpcodeId::MOD => ExecutionState::MUL_DIV_MOD, OpcodeId::EQ | OpcodeId::LT | OpcodeId::GT => ExecutionState::CMP, + OpcodeId::SHR => ExecutionState::SHR, OpcodeId::SLT | OpcodeId::SGT => ExecutionState::SCMP, + OpcodeId::SHL => ExecutionState::SHL, OpcodeId::SIGNEXTEND => ExecutionState::SIGNEXTEND, // TODO: Convert REVERT and RETURN to their own ExecutionState. OpcodeId::STOP | OpcodeId::RETURN | OpcodeId::REVERT => ExecutionState::STOP, diff --git a/zkevm-circuits/src/test_util.rs b/zkevm-circuits/src/test_util.rs index cbcd2a9852..fa459b2a7c 100644 --- a/zkevm-circuits/src/test_util.rs +++ b/zkevm-circuits/src/test_util.rs @@ -33,6 +33,8 @@ pub fn get_fixed_table(conf: FixedTableConfig) -> Vec { FixedTableTag::Range1024, FixedTableTag::SignByte, FixedTableTag::ResponsibleOpcode, + FixedTableTag::Bitslevel, + FixedTableTag::Pow64, ] } FixedTableConfig::Complete => FixedTableTag::iterator().collect(),