diff --git a/specs/opcode/1bSHL.md b/specs/opcode/1bSHL.md new file mode 100644 index 000000000..c95d90478 --- /dev/null +++ b/specs/opcode/1bSHL.md @@ -0,0 +1,134 @@ +# SHL opcode + +## Procedure + +The `SHL` opcode shifts the bits towards the most significant one. The bits moved past the 256th one are discarded, and the new bits are set to 0. + +### EVM behavior + +Pop two EVM words `a` and `shift` from the stack, and push `b` to the stack, where `b` is computed as: + +1. If `shift >= 256`, then `b` is set to zero. +2. If `shift < 256`, compute `b = a << shift`. + +### Circuit behavior + +To prove the `SHL` opcode, we first construct a `ShlGadget` that proves `a << shift == b` where `a, b, shift` are all 256-bit words. +As usual, we use 32 cells to represent each of the words `a` and `b`, where each cell holds an 8-bit value. We then split each word into four 64-bit limbs denoted by `a64s[idx]` and `b64s[idx]`, where `idx` is in `(0, 1, 2, 3)`. +We put the lower `64 - n` bits of a limb into the `lo` array, and put the higher `n` bits into the `hi` array, where `n` is `shift % 64`. During the SHL operation, the `lo` array will move to higher bits of the result, and the `hi` array will move to lower bits of the result. + +The following figure illustrates how shift left works in the case of `shift < 64`. + +``` +------+-------------------------------+-------------------------------+------ + |a0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| ... +------+-------------------------------+-------------------------------+------ + | a64s[0] | a64s[1] | ... +------+------------+------------------+------------+------------------+------ + | a64s_lo[0] | a64s_hi[0] | a64s_lo[1] | a64s_hi[1] | ... 
+------+------------+------------------+------------+------------------+------ +b64s[0] | b64s[1] | b64s[2] +------+-------------------------------+------------+------------------------- +``` + +More formally, the variables are defined as follows: + +``` +shf0 = bytes_to_fq(shift.le_bytes[:1]) +shf_div64 = shf0 // 64 +shf_mod64 = shf0 % 64 +shf_lt256 = is_zero(sum(shift[1:])) +p_lo = 1 << (64 - shf_mod64) +p_hi = 1 << shf_mod64 +a64s = word_to_64s(a) +a64s_lo[idx] = a64s[idx] % p_lo +a64s_hi[idx] = a64s[idx] // p_lo +``` + +If `shift >= 256`, `b64s` are all 0. Otherwise, `b64s` can be calculated by computing `a << shf0` and splitting the result into four 64-bit limbs. + +Now putting things together, the constraints can be constructed as follows: + +1. `a64s` and `b64s` constraints: + +* First calculate `shf_lt256` as `is_zero(sum(shift[1:]))`. +* `a64s[idx]`: It should be equal to `from_bytes(a[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. +* `b64s[idx] * shf_lt256`: It should be equal to `from_bytes(b[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. + +2. `a64s_lo` and `a64s_hi` constraints: + +* `a64s[idx]`: It should be equal to `a64s_lo[idx] + a64s_hi[idx] * p_lo`. +* `a64s_hi[idx]`: It should always be less than `p_hi` (`a64s_hi[idx] < p_hi`). + +3. 
Merge constraints: + +* First create three `IsZero` gadgets: +``` +shf_div64_eq0 = is_zero(shf_div64) +shf_div64_eq1 = is_zero(shf_div64 - 1) +shf_div64_eq2 = is_zero(shf_div64 - 2) +``` + +* `b64s[0]` should be equal to: +``` +shf_div64_eq0 * a64s_lo[0] * p_hi +``` + +* `b64s[1]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq1 * a64s_lo[0] * p_hi +``` + +* `b64s[2]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq2 * a64s_lo[0] * p_hi +``` + +* `b64s[3]` should be equal to: +``` +shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi) + + shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi +``` + +4. `shift[0]` constraint: + +* `shift[0]`: It should be equal to `shf_mod64 + shf_div64 * 64`. + +5. `Pow2` table lookup: + +* First build the `Pow2` table from tuples `(value, value_pow)` that satisfy `value_pow == pow(2, value)` + +* Look up `(64 - shf_mod64, p_lo)` and `(shf_mod64, p_hi)` + +6. Stack pop and push: + +* Pop word `a` +* Pop word `shift` +* Push word `shf_lt256 * b` + +## Constraints + +1. opId = OpcodeId(0x1b) +2. state transition: + - gc + 3 (2 stack reads + 1 stack write) + - stack\_pointer + 1 + - pc + 1 + - gas + 3 +3. lookups: 3 busmapping lookups + - `a` is at the top of the stack + - `shift` is at the second position of the stack + - `b`, the result, is at the new top of the stack + +## Exceptions + +1. stack underflow: `1023 <= stack_pointer <= 1024` +2. 
out of gas: remaining gas is not enough + +## Code + +See `src/zkevm_specs/evm/execution/shl.py` diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 287dc13a2..f5cf4bbbf 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -33,6 +33,7 @@ from .selfbalance import * from .extcodehash import * from .log import * +from .shl import * EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { @@ -65,4 +66,5 @@ ExecutionState.LOG: log, ExecutionState.CALL: call, ExecutionState.ISZERO: iszero, + ExecutionState.SHL: shl, } diff --git a/src/zkevm_specs/evm/execution/shl.py b/src/zkevm_specs/evm/execution/shl.py new file mode 100644 index 000000000..a5a4b6a67 --- /dev/null +++ b/src/zkevm_specs/evm/execution/shl.py @@ -0,0 +1,154 @@ +from ...util import FQ, N_BYTES_U64, RLC +from ..instruction import Instruction, Transition +from ..typing import Sequence + + +def shl(instruction: Instruction): + """Constrain one SHL step: pop `a` and `shift`, push `b`, and check the limb witness.""" + opcode = instruction.opcode_lookup(True) + + # NOTE(review): EIP-145 pops the shift amount first and the shifted value second, whereas `a` + # is popped first here (the test's RW rows follow this order) - confirm the stack layout. + a = instruction.stack_pop() + shift = instruction.stack_pop() + b = instruction.stack_push() + + ( + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) = gen_witness(instruction, a, shift) + check_witness( + instruction, + a, + shift, + b, + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) + + # Two stack pops plus one stack push consume three RW operations. + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(3), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(1), + ) + + +def check_witness( + instruction: Instruction, + a: RLC, + shift: RLC, + b: RLC, + a64s: Sequence[FQ], + b64s: Sequence[FQ], + a64s_lo: Sequence[FQ], + a64s_hi: Sequence[FQ], + shf_div64, + shf_mod64, + p_lo, + p_hi, +): + """Constrain the witness to encode `b == a << shift` (`b == 0` when `shift >= 256`).""" + shf_lt256 = instruction.is_zero(instruction.sum(shift.le_bytes[1:])) + for idx in range(4): + offset = idx * N_BYTES_U64 + + # a64s constraint + instruction.constrain_equal( + a64s[idx], + 
instruction.bytes_to_fq(a.le_bytes[offset : offset + N_BYTES_U64]), + ) + + # b64s constraint + instruction.constrain_equal( + b64s[idx] * shf_lt256, + instruction.bytes_to_fq(b.le_bytes[offset : offset + N_BYTES_U64]), + ) + + # `a64s[idx] == a64s_lo[idx] + a64s_hi[idx] * p_lo` + instruction.constrain_equal(a64s[idx], a64s_lo[idx] + a64s_hi[idx] * p_lo) + + # `a64s_hi[idx] < p_hi` + # + # TRICKY: + # `p_lo` is `1 << 64` when `shf_mod64` is zero, which overflows `N_BYTES_U64` (8 bytes), so + # `a64s_hi[idx]` is compared against `p_hi` here instead. + a64s_hi_lt_p_hi, _ = instruction.compare(a64s_hi[idx], p_hi, N_BYTES_U64) + instruction.constrain_equal(a64s_hi_lt_p_hi, FQ(1)) + + # merge constraints + shf_div64_eq0 = instruction.is_zero(shf_div64) + shf_div64_eq1 = instruction.is_zero(shf_div64 - 1) + shf_div64_eq2 = instruction.is_zero(shf_div64 - 2) + instruction.constrain_equal(b64s[0], shf_div64_eq0 * a64s_lo[0] * p_hi) + instruction.constrain_equal( + b64s[1], + shf_div64_eq0 * (a64s_hi[0] + a64s_lo[1] * p_hi) + shf_div64_eq1 * a64s_lo[0] * p_hi, + ) + instruction.constrain_equal( + b64s[2], + shf_div64_eq0 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq1 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + shf_div64_eq2 * a64s_lo[0] * p_hi, + ) + instruction.constrain_equal( + b64s[3], + shf_div64_eq0 * (a64s_hi[2] + a64s_lo[3] * p_hi) + + shf_div64_eq1 * (a64s_hi[1] + a64s_lo[2] * p_hi) + + shf_div64_eq2 * (a64s_hi[0] + a64s_lo[1] * p_hi) + + (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) * a64s_lo[0] * p_hi, + ) + + # shift constraint + instruction.constrain_equal( + instruction.bytes_to_fq(shift.le_bytes[:1]), + shf_mod64 + shf_div64 * 64, + ) + + # `p_lo == pow(2, 64 - shf_mod64)` and `p_hi == pow(2, shf_mod64)`. 
+ instruction.pow2_lookup(64 - shf_mod64, p_lo) + instruction.pow2_lookup(shf_mod64, p_hi) + + +def gen_witness(instruction: Instruction, a: RLC, shift: RLC): + """Compute the limbs, lo/hi splits and powers of two that `check_witness` constrains.""" + shf0 = instruction.bytes_to_fq(shift.le_bytes[:1]) + shf_div64 = FQ(shf0.n // 64) + shf_mod64 = FQ(shf0.n % 64) + # SHL keeps the lower `64 - shf_mod64` bits of each limb in `lo` (the reverse of SHR). + p_lo = FQ(1 << (64 - shf_mod64.n)) + p_hi = FQ(1 << shf_mod64.n) + + a64s = instruction.word_to_64s(a) + a64s_lo = [FQ(0)] * 4 + a64s_hi = [FQ(0)] * 4 + for idx in range(4): + a64s_lo[idx] = FQ(a64s[idx].n % p_lo.n) + a64s_hi[idx] = FQ(a64s[idx].n // p_lo.n) + + bb = a.int_value << shf0.n + b64s = [FQ((bb >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)] + + return ( + a64s, + b64s, + a64s_lo, + a64s_hi, + shf_div64, + shf_mod64, + p_lo, + p_hi, + ) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index b822ed352..90840f893 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -817,3 +817,7 @@ def memory_copier_gas_cost( gas_cost = word_size * GAS_COST_COPY + memory_expansion_gas_cost self.range_check(gas_cost, N_BYTES_GAS) return gas_cost + + def pow2_lookup(self, value: Expression, value_pow: Expression): + """Look up `(value, value_pow)` under the `Pow2` tag of the fixed table.""" + self.fixed_lookup(FixedTableTag.Pow2, value, value_pow) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 3e0b692e9..7cb16a7dc 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -26,6 +26,7 @@ class FixedTableTag(IntEnum): BitwiseOr = auto() # lhs, rhs, lhs | rhs, 0 BitwiseXor = auto() # lhs, rhs, lhs ^ rhs, 0 ResponsibleOpcode = auto() # execution_state, opcode, aux + Pow2 = auto() # value, value_pow def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range5: @@ -68,6 +69,8 @@ def table_assignments(self) -> List[FixedTableRow]: execution_state.responsible_opcode(), ) ] + elif self == FixedTableTag.Pow2: + return [FixedTableRow(FQ(self), FQ(value), FQ(1 << value)) for 
value in range(65)] else: raise ValueError("Unreacheable") diff --git a/tests/evm/test_shl.py b/tests/evm/test_shl.py new file mode 100644 index 000000000..bd55c8f38 --- /dev/null +++ b/tests/evm/test_shl.py @@ -0,0 +1,89 @@ +import pytest + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + Block, + Bytecode, + RWDictionary, +) +from zkevm_specs.util import ( + rand_fq, + rand_range, + rand_word, + RLC, + U256, +) + + +TESTING_MAX_RLC = (1 << 256) - 1 + +TESTING_DATA = ( + (0xABCD << 240, 8), + (0x1234 << 240, 7), + (0x8765 << 240, 17), + (0x4321 << 240, 0), + (0xFFFF, 256), + (0x12345, 256 + 8 + 1), + (TESTING_MAX_RLC, 63), + (TESTING_MAX_RLC, 128), + (TESTING_MAX_RLC, 129), +) + + +@pytest.mark.parametrize("value, shift", TESTING_DATA) +def test_shl(value: U256, shift: int): + """Verify the SHL step for shifts below 256 and for shifts of 256 or more.""" + result = (value << shift) & TESTING_MAX_RLC if shift <= 255 else 0 + + randomness = rand_fq() + value = RLC(value, randomness) + shift = RLC(shift, randomness) + result = RLC(result, randomness) + + bytecode = Bytecode().push32(value).push32(shift).shl().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(), + bytecode_table=set(bytecode.table_assignments(randomness)), + rw_table=set( + RWDictionary(9) + .stack_read(1, 1022, value) + .stack_read(1, 1023, shift) + .stack_write(1, 1023, result) + .rws + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.SHL, + rw_counter=9, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=66, + stack_pointer=1022, + gas_left=3, + ), + StepState( + execution_state=ExecutionState.STOP, + rw_counter=12, + call_id=1, + is_root=True, + is_create=False, + code_source=bytecode_hash, + program_counter=67, + stack_pointer=1023, + gas_left=0, + ), + ], + )