From 9b1d145558fb2138fd56c5e5f5c7be695192209c Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Tue, 17 May 2022 09:37:25 +0800 Subject: [PATCH 1/4] Add test case for `SHL`. --- src/zkevm_specs/evm/execution/__init__.py | 2 + tests/evm/test_shl.py | 86 +++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/evm/test_shl.py diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 287dc13a2..f5cf4bbbf 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -33,6 +33,7 @@ from .selfbalance import * from .extcodehash import * from .log import * +from .shl import * EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { @@ -65,4 +66,5 @@ ExecutionState.LOG: log, ExecutionState.CALL: call, ExecutionState.ISZERO: iszero, + ExecutionState.SHL: shl, } diff --git a/tests/evm/test_shl.py b/tests/evm/test_shl.py new file mode 100644 index 000000000..caa424dc0 --- /dev/null +++ b/tests/evm/test_shl.py @@ -0,0 +1,86 @@ +import pytest + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + Block, + Bytecode, + RWDictionary, +) +from zkevm_specs.util import ( + rand_fq, + rand_range, + rand_word, + RLC, + U256, +) + + +TESTING_DATA = ( + (0xABCD << 240, 8), + (0x1234 << 240, 7), + (0x8765 << 240, 17), + (0x4321 << 240, 0), + (0xFFFF, 256), + (0x12345, 256 + 8 + 1), + ((1 << 256) - 1, 63), + ((1 << 256) - 1, 128), + ((1 << 256) - 1, 129), +) + + +@pytest.mark.parametrize("value, shift", TESTING_DATA) +def test_shl(value: U256, shift: int): + result = value << shift if shift <= 255 else 0 + + randomness = rand_fq() + value = RLC(value, randomness) + shift = RLC(shift, randomness) + result = RLC(result, randomness) + + bytecode = Bytecode().push32(value).push32(shift).shl().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(), + 
bytecode_table=set(bytecode.table_assignments(randomness)), + rw_table=set( + RWDictionary(9) + .stack_read(1, 1022, value) + .stack_read(1, 1023, shift) + .stack_write(1, 1023, result) + .rws + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.SHL, + rw_counter=9, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=66, + stack_pointer=1022, + gas_left=3, + ), + StepState( + execution_state=ExecutionState.STOP, + rw_counter=11, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=67, + stack_pointer=1023, + gas_left=0, + ), + ], + ) From 35f66e086a3c782f913c1df0bcbd6481bde64a8b Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 18 May 2022 11:04:27 +0800 Subject: [PATCH 2/4] Implement opcode `SHL`. --- src/zkevm_specs/evm/execution/shl.py | 146 +++++++++++++++++++++++++++ tests/evm/test_shl.py | 10 +- 2 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 src/zkevm_specs/evm/execution/shl.py diff --git a/src/zkevm_specs/evm/execution/shl.py b/src/zkevm_specs/evm/execution/shl.py new file mode 100644 index 000000000..6bfc974d5 --- /dev/null +++ b/src/zkevm_specs/evm/execution/shl.py @@ -0,0 +1,146 @@ +from ...util import FQ, N_BYTES_U64, RLC +from ..instruction import Instruction, Transition +from ..typing import Sequence + + +def shl(instruction: Instruction): + opcode = instruction.opcode_lookup(True) + + a = instruction.stack_pop() + shift = instruction.stack_pop() + b = instruction.stack_push() + + ( + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) = gen_witness(instruction, a, shift) + check_witness( + instruction, + a, + shift, + b, + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) + + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(2), + program_counter=Transition.delta(1), + 
stack_pointer=Transition.delta(1), + ) + + +def check_witness( + instruction: Instruction, + a: RLC, + shift: RLC, + b: RLC, + a64s: Sequence[FQ], + b64s: Sequence[FQ], + a64s_lo: Sequence[FQ], + a64s_hi: Sequence[FQ], + shf_div64, + shf_mod64, + p_lo, + p_hi, +): + shf_lt256 = instruction.is_zero(instruction.sum(shift.le_bytes[1:])) + for idx in range(4): + offset = idx * N_BYTES_U64 + + # a64s constraint + instruction.constrain_equal( + a64s[idx], + instruction.bytes_to_fq(a.le_bytes[offset : offset + N_BYTES_U64]), + ) + + # b64s constraint + instruction.constrain_equal( + b64s[idx] * shf_lt256, + instruction.bytes_to_fq(b.le_bytes[offset : offset + N_BYTES_U64]), + ) + + # `a64s[idx] == a64s_lo[idx] + a64s_hi[idx] * p_lo` + instruction.constrain_equal(a64s[idx], a64s_lo[idx] + a64s_hi[idx] * p_lo) + + # `a64s_lo[idx] < p_lo` + # TRICKY: `p_lo` could be equal to `1 << 64` (greater than 8 bytes) if `shf_mod64` is zero. + a64s_lo_lt_p_lo, _ = instruction.compare(a64s_lo[idx], p_lo, N_BYTES_U64 + 1) + assert p_lo.expr().n <= 256**8, f"p_lo is overflow: {p_lo}" + instruction.constrain_equal(a64s_lo_lt_p_lo, FQ(1)) + + # merge contraints + shf_div64_eq0 = instruction.is_zero(shf_div64) + shf_div64_eq1 = instruction.is_zero(shf_div64 - 1) + shf_div64_eq2 = instruction.is_zero(shf_div64 - 2) + instruction.constrain_equal(b64s[0], shf_div64_eq0 * a64s_lo[0] * p_hi) + instruction.constrain_equal( + b64s[1], + shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) + shf_div64_eq1 * a64s_lo[0] * p_hi, + ) + instruction.constrain_equal( + b64s[2], + shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq2 * a64s_lo[0] * p_hi, + ) + instruction.constrain_equal( + b64s[3], + shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi) + + shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi, + ) + + # shift 
constraint + instruction.constrain_equal( + instruction.bytes_to_fq(shift.le_bytes[:1]), + shf_mod64 + shf_div64 * 64, + ) + + # `p_lo == pow(2, 64 - shf_mod64)` and `p_hi == pow(2, shf_mod64)`. + instruction.pow2_lookup(64 - shf_mod64, p_lo) + instruction.pow2_lookup(shf_mod64, p_hi) + + +def gen_witness(instruction: Instruction, a: RLC, shift: RLC): + shf0 = instruction.bytes_to_fq(shift.le_bytes[:1]) + shf_div64 = FQ(shf0.n // 64) + shf_mod64 = FQ(shf0.n % 64) + # Remain lower bits of `64 - shf_mod64` for SHL (reverse to SHR). + p_lo = FQ(1 << (64 - shf_mod64.n)) + p_hi = FQ(1 << shf_mod64.n) + + a64s = instruction.word_to_64s(a) + a64s_lo = [FQ(0)] * 4 + a64s_hi = [FQ(0)] * 4 + for idx in range(4): + a64s_lo[idx] = FQ(a64s[idx].n % p_lo.n) + a64s_hi[idx] = FQ(a64s[idx].n // p_lo.n) + + bb = a.int_value << shf0.n + b64s = [FQ((bb >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)] + + return ( + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) diff --git a/tests/evm/test_shl.py b/tests/evm/test_shl.py index caa424dc0..bd55c8f38 100644 --- a/tests/evm/test_shl.py +++ b/tests/evm/test_shl.py @@ -18,6 +18,8 @@ ) +TESTING_MAX_RLC = (1 << 256) - 1 + TESTING_DATA = ( (0xABCD << 240, 8), (0x1234 << 240, 7), @@ -25,15 +27,15 @@ (0x4321 << 240, 0), (0xFFFF, 256), (0x12345, 256 + 8 + 1), - ((1 << 256) - 1, 63), - ((1 << 256) - 1, 128), - ((1 << 256) - 1, 129), + (TESTING_MAX_RLC, 63), + (TESTING_MAX_RLC, 128), + (TESTING_MAX_RLC, 129), ) @pytest.mark.parametrize("value, shift", TESTING_DATA) def test_shl(value: U256, shift: int): - result = value << shift if shift <= 255 else 0 + result = value << shift & TESTING_MAX_RLC if shift <= 255 else 0 randomness = rand_fq() value = RLC(value, randomness) From 1939133601e72e9c714ed192c1e90de04dec3a12 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 18 May 2022 21:04:42 +0800 Subject: [PATCH 3/4] Fix to validate `a64s_hi[idx] < p_hi`, and add Markdown doc for `SHL`. 
--- specs/opcode/1bSHL.md | 134 +++++++++++++++++++++++++++ src/zkevm_specs/evm/execution/shl.py | 12 ++- 2 files changed, 141 insertions(+), 5 deletions(-) create mode 100644 specs/opcode/1bSHL.md diff --git a/specs/opcode/1bSHL.md b/specs/opcode/1bSHL.md new file mode 100644 index 000000000..c95d90478 --- /dev/null +++ b/specs/opcode/1bSHL.md @@ -0,0 +1,134 @@ +# SHL opcode + +## Procedure + +The `SHL` opcode shifts the bits towards the most significant one. Bits shifted beyond the 256th one are discarded, and the new low-order bits are set to 0. + +### EVM behavior + +Pop two EVM words `a` and `shift` from the stack, and push `b` to the stack, where `b` is computed as: + +1. If `shift >= 256`, then `b` is set to zero. +2. If `shift < 256`, compute `b = a << shift`. + +### Circuit behavior + +To prove the `SHL` opcode, we first construct a `ShlGadget` that proves `a << shift == b` where `a, b, shift` are all 256-bit words. +As usual, we use 32 cells to represent each of the words `a` and `b`, where each cell holds an 8-bit value. Then split each word into four 64-bit limbs denoted by `a64s[idx]` and `b64s[idx]` where idx in `(0, 1, 2, 3)`. +We put the lower `64 - n` bits of a limb into the `lo` array, and put the higher `n` bits into the `hi` array, where `n` is `shift % 64`. During the SHL operation, the `lo` array will move to higher bits of the result, and the `hi` array will move to lower bits of the result. + +The following figure illustrates how shift left works under the case of `shift < 64`. + +``` +------+-------------------------------+-------------------------------+------ + |a0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| ... +------+-------------------------------+-------------------------------+------ + | a64s[0] | a64s[1] | ... +------+------------+------------------+------------+------------------+------ + | a64s_lo[0] | a64s_hi[0] | a64s_lo[1] | a64s_hi[1] | ... 
+------+------------+------------------+------------+------------------+------ +b64s[0] | b64s[1] | b64s[2] +------+-------------------------------+------------+------------------------- +``` + +More formally, the variables are defined as follows: + +``` +shf0 = bytes_to_fq(shift.le_bytes[:1]) +shf_div64 = shf0 // 64 +shf_mod64 = shf0 % 64 +shf_lt256 = is_zero(sum(shift[1:])) +p_lo = 1 << (64 - shf_mod64) +p_hi = 1 << shf_mod64 +a64s = word_to_64s(a) +a64s_lo[idx] = a64s[idx] % p_lo +a64s_hi[idx] = a64s[idx] // p_lo +``` + +`b64s` is always calculated as `a << shf0`, truncated to 256 bits and then split into four 64-bit limbs. If `shift >= 256`, `shf_lt256` is 0, so the pushed word `b` is 0. + +Now putting things together, the constraints can be constructed as follows: + +1. `a64s` and `b64s` constraints: + +* First calculate `shf_lt256` as `is_zero(sum(shift[1:]))`. +* `a64s[idx]`: It should be equal to `from_bytes(a[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. +* `b64s[idx] * shf_lt256`: It should be equal to `from_bytes(b[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. + +2. `a64s_lo` and `a64s_hi` constraints: + +* `a64s[idx]`: It should be equal to `a64s_lo[idx] + a64s_hi[idx] * p_lo`. +* `a64s_hi[idx]`: It should always be less than `p_hi` (`a64s_hi[idx] < p_hi`). + +3. 
Merge constraints: + +* First create three `IsZero` gadgets: +``` +shf_div64_eq0 = is_zero(shf_div64) +shf_div64_eq1 = is_zero(shf_div64 - 1) +shf_div64_eq2 = is_zero(shf_div64 - 2) +``` + +* `b64s[0]` should be equal to: +``` +shf_div64_eq0 * a64s_lo[0] * p_hi +``` + +* `b64s[1]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq1 * a64s_lo[0] * p_hi +``` + +* `b64s[2]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq2 * a64s_lo[0] * p_hi +``` + +* `b64s[3]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi) + + shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi +``` + +4. `shift[0]` constraint: + +* `shift[0]`: It should be equal to `shf_mod64 + shf_div64 * 64`. + +5. `Pow2` table lookup: + +* First build the `Pow2` table from tuples `(value, value_pow)` that satisfy `value_pow == pow(2, value)` + +* Look up `(64 - shf_mod64, p_lo)` and `(shf_mod64, p_hi)` + +6. Stack pop and push: + +* Pop word `a` +* Pop word `shift` +* Push word `b` (each 64-bit limb of `b` equals `b64s[idx] * shf_lt256`) + +## Constraints + +1. opId = OpcodeId(0x1b) +2. state transition: + - gc + 3 (2 stack reads + 1 stack write) + - stack\_pointer + 1 + - pc + 1 + - gas + 3 +3. lookups: 3 busmapping lookups + - `a` is at the top of the stack + - `shift` is at the second position of the stack + - `b`, the result, is at the new top of the stack + +## Exceptions + +1. stack underflow: `1023 <= stack_pointer <= 1024` +2. 
out of gas: remaining gas is not enough + +## Code + +See `src/zkevm_specs/evm/execution/shl.py` diff --git a/src/zkevm_specs/evm/execution/shl.py b/src/zkevm_specs/evm/execution/shl.py index 6bfc974d5..a5a4b6a67 100644 --- a/src/zkevm_specs/evm/execution/shl.py +++ b/src/zkevm_specs/evm/execution/shl.py @@ -76,11 +76,13 @@ def check_witness( # `a64s[idx] == a64s_lo[idx] + a64s_hi[idx] * p_lo` instruction.constrain_equal(a64s[idx], a64s_lo[idx] + a64s_hi[idx] * p_lo) - # `a64s_lo[idx] < p_lo` - # TRICKY: `p_lo` could be equal to `1 << 64` (greater than 8 bytes) if `shf_mod64` is zero. - a64s_lo_lt_p_lo, _ = instruction.compare(a64s_lo[idx], p_lo, N_BYTES_U64 + 1) - assert p_lo.expr().n <= 256**8, f"p_lo is overflow: {p_lo}" - instruction.constrain_equal(a64s_lo_lt_p_lo, FQ(1)) + # `a64s_hi[idx] < p_hi` + # + # TRICKY: + # Since `p_lo` could be equal to `1 << 64` that is greater than `N_BYTES_U64`(8 bytes) if + # `shf_mod64` is zero. Alternative to compare `a64s_hi[idx]` and `p_hi` here. + a64s_hi_lt_p_hi, _ = instruction.compare(a64s_hi[idx], p_hi, N_BYTES_U64) + instruction.constrain_equal(a64s_hi_lt_p_hi, FQ(1)) # merge contraints shf_div64_eq0 = instruction.is_zero(shf_div64) From 897ad082dbd66027416f16adb18c61e6508422a0 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 19 May 2022 09:11:17 +0800 Subject: [PATCH 4/4] Move some code from `SHR` to make it work with `master` branch. 
--- src/zkevm_specs/evm/instruction.py | 3 +++ src/zkevm_specs/evm/table.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index b822ed352..90840f893 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -817,3 +817,6 @@ def memory_copier_gas_cost( gas_cost = word_size * GAS_COST_COPY + memory_expansion_gas_cost self.range_check(gas_cost, N_BYTES_GAS) return gas_cost + + def pow2_lookup(self, value: Expression, value_pow: Expression): + self.fixed_lookup(FixedTableTag.Pow2, value, value_pow) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 3e0b692e9..7cb16a7dc 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -26,6 +26,7 @@ class FixedTableTag(IntEnum): BitwiseOr = auto() # lhs, rhs, lhs | rhs, 0 BitwiseXor = auto() # lhs, rhs, lhs ^ rhs, 0 ResponsibleOpcode = auto() # execution_state, opcode, aux + Pow2 = auto() # value, value_pow def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range5: @@ -68,6 +69,8 @@ def table_assignments(self) -> List[FixedTableRow]: execution_state.responsible_opcode(), ) ] + elif self == FixedTableTag.Pow2: + return [FixedTableRow(FQ(self), FQ(value), FQ(1 << value)) for value in range(65)] else: raise ValueError("Unreacheable")