diff --git a/specs/opcode/1bSHL_1cSHR.md b/specs/opcode/1bSHL_1cSHR.md new file mode 100644 index 000000000..63eb0601c --- /dev/null +++ b/specs/opcode/1bSHL_1cSHR.md @@ -0,0 +1,57 @@ +# SHL and SHR opcodes + +## Procedure + +### EVM behavior + +Pop two EVM words `shift` and `a` from the stack, and push `b` to the stack, where `b` is computed as + +- for opcode SHL, `shift` is a number of bits to shift to the left, compute `b = (a * 2^shift) % 2^256` when `shift < 256` otherwise `b = 0` +- for opcode SHR, `shift` is a number of bits to shift to the right, compute `b = a // 2^shift` when `shift < 256` otherwise `b = 0` + +### Circuit behavior + +To prove the SHL and SHR opcodes, we first construct a `MulAddWordsGadget` that proves `quotient * divisor + remainder = dividend (% 2^256)` where quotient, divisor, remainder and dividend are all 256-bit words. Reference `02MUL_04DIV_06MOD.md` for details about `MulAddWordsGadget`. + +Based on different opcode cases, we constrain the stack pops and pushes as follows + +- for opcode SHL, two stack pops are shift and quotient, when `divisor = 2^shift` if `shift < 256` and 0 otherwise. The stack push is dividend if `shift < 256` and 0 otherwise. +- for opcode SHR, two stack pops are shift and dividend, when `divisor = 2^shift` if `shift < 256` and 0 otherwise. The stack push is quotient if `shift < 256` and 0 otherwise. + +The opcode circuit also adds some extra constraints: + +- contrain `shift == shift.cells[0]` when `divisor != 0`. +- use a `LtWordGadget` to constrain `remainder < divisor` when `divisor != 0`. +- if the opcode is SHL, constrain `remainder == 0`. +- if the opcode is SHR, constrain `overflow == 0` in `MulAddWordsGadget`. + +## Constraints + +1. opcodeId checks + - opId === OpcodeId(0x1b) for SHL + - opId === OpcodeId(0x1c) for SHR +2. state transition: + - gc + 3 + - stack_pointer + 1 + - pc + 1 + - gas + 3 +3. Lookups: 1 pow2 lookup + 3 busmapping lookups + - divisor lookup in pow2 table (where 0≤shf0<256) + - when `shf0 < 128`, constrain `divisor_lo == 2^shf0`. + - when `shf0 >= 128`, constrain `divisor_hi == 2^(shf0 - 128)`. + - top of the stack + - when opcode is SHL, quotient is at the top of the stack. + - when opcode is SHR, dividend is at the top of the stack. + - shift is at the second position of the stack when `divisor = 2^shift`. + - new top of the stack + - when opcode is SHL, dividend is at the new top of the stack. + - when opcode is SHR, quotient is at the new top of the stack if `divisor != 0` otherwise 0. + +## Exceptions + +1. stack underflow: `1023 <= stack_pointer <= 1024` +2. out of gas: remaining gas is not enough + +## Code + +See `src/zkevm_specs/evm/execution/shl_shr.py` diff --git a/specs/opcode/1cSHR.md b/specs/opcode/1cSHR.md deleted file mode 100644 index 6faad0496..000000000 --- a/specs/opcode/1cSHR.md +++ /dev/null @@ -1,134 +0,0 @@ -# SHR opcode - -## Procedure - -The `SHR` opcode shifts the bits towards the least significant one. The bits moved before the first one are discarded, the new bits are set to 0. - -### EVM behavior - -Pop two EVM words `a` and `shift` from the stack, and push `b` to the stack, where `b` is computed as: - -1. If `shift >= 256`,then `b` is set to zero. -2. If `shift < 256`,compute `b = a >> shift`. - -### Circuit behavior - -To prove the `SHR` opcode, we first construct a `ShrGadget` that proves `a >> shift = b` where `a, b, shift` are all 256-bit words. -As usual, we use 32 cells to represent word `a` and `b`, where each cell holds a 8-bit value. Then split each word into four 64-bit limbs denoted by `a64s[idx]` and `b64s[idx]` where idx in `(0, 1, 2, 3)`. -We put the lower `n` bits of a limb into the `lo` array, and put the higher `64 - n` bits into the `hi` array, where `n` is `shift % 64`. During the SHR operation, the `lo` array will move to higher bits of the result, and the `hi` array will move to lower bits of the result. - -The following figure illustrates how shift right works under the case of `shift < 64`. - -``` -+-------------------------------+-------------------------------+----- -|a0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15| ... -+-------------------------------+-------------------------------+----- -| a64s[0] | a64s[1] | ... -+------------+------------------+------------+------------------+----- -| a64s_lo[0] | a64s_hi[0] | a64s_lo[1] | a64s_hi[1] | ... -+------------+------------------+------------+------------------+----- - | b64s[0] | b64s[1] - +-------------------------------+------------------------ -``` - -More formally, the variables are defined as follows: - -``` -shf0 = bytes_to_fq(shift.le_bytes[:1]) -shf_div64 = shift // 64 -shf_mod64 = shift % 64 -shf_lt256 = is_zero(sum(shift[1:])) -p_lo = 1 << shf_mod64 -p_hi = 1 << (64 - shf_mod64) -a64s = word_to_64s(a) -a64s_lo[idx] = a64s[idx] % p_lo -a64s_hi[idx] = a64s[idx] / p_lo -``` - -If `shift >= 256`, `b64s` are all 0. Otherwise, `b64s` can be calculated by `a >> shf0` then split into four 64-bit limbs. - -Now putting things together, the constraints can be constructed as follows: - -1. `a64s` and `b64s` constraints: - -* First calculate `shf_lt256` as `is_zero(sum(shift[1:]))`. -* `a64s[idx]`: It should be equal to `from_bytes(a[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. -* `b64s[idx] * shf_lt256`: It should be equal to `from_bytes(b[8 * idx : 8 * (idx + 1)])` where idx in `(0, 1, 2, 3)`. - -2. `a64s_lo` and `a64s_hi` constraints: - -* `a64s[idx]`: It should be equal to `a64s_lo[idx] + a64s_hi[idx] * p_lo`. -* `a64s_lo[idx]`: It should always be less than `p_lo` (`a64s_lo[idx] < p_lo`). - -3. Merge constraints: - -* First create three `IsZero` gadgets: -``` -shf_div64_eq0 = is_zero(shf_div64) -shf_div64_eq1 = is_zero(shf_div64 - 1) -shf_div64_eq2 = is_zero(shf_div64 - 2) -``` - -* `b64s[0]` should be equal to: -``` -(a64s_hi[0] + a64s_lo[1] * p_hi) * shf_div64_eq0 + - (a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq1 + - (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq2 + - a64s_hi[3] * (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2) -``` - -* `b64s[1]` should be equal to: -``` -(a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq0 + - (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq1 + - a64s_hi[3] * shf_div64_eq2 -``` - -* `b64s[2]` should be equal to: -``` -(a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq0 + - a64s_hi[3] * shf_div64_eq1 -``` - -* `b64s[3]` should be equal to: -``` -a64s_hi[3] * shf_div64_eq0 -``` - -4. `shift[0]` constraint: - -* `shift[0]`: It should be equal to `shf_mod64 + shf_div64 * 64`. - -5. `Pow2` table look up: - -* First build `Pow2` table by tuple $[value, value\_pow]$ which meets $${value\_pow == 2^{value}}$$ - -* Look up for `(shf_mod64, p_lo)` and `(64 - shf_mod64, p_hi)` - -6. Stack pop and push: - -* Pop word `a` -* Pop word `shift` -* Push word `shift_lt256 * b` - -## Constraints - -1. opId = OpcodeId(0x1c) -2. state transition: - - gc + 3 (2 stack reads + 1 stack write) - - stack\_pointer + 1 - - pc + 1 - - gas + 3 -3. lookups: 3 busmapping lookups - - `a` is at the top of the stack - - `shift` is at the second position of the stack - - `b`, the result, is at the new top of the stack - -## Exceptions - -1. stack underflow: `1023 <= stack_pointer <= 1024` -2. out of gas: remaining gas is not enough - -## Code - -See `src/zkevm_specs/evm/execution/shr.py` diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index d1b4db432..7e1620492 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -33,10 +33,10 @@ from .selfbalance import * from .extcodehash import * from .log import * -from .sha3 import sha3 -from .shr import shr from .bitwise import not_opcode from .sdiv_smod import sdiv_smod +from .sha3 import sha3 +from .shl_shr import shl_shr from .stop import stop from .return_ import * @@ -73,8 +73,8 @@ ExecutionState.LOG: log, ExecutionState.CALL: call, ExecutionState.ISZERO: iszero, - ExecutionState.SHR: shr, ExecutionState.SDIV_SMOD: sdiv_smod, + ExecutionState.SHL_SHR: shl_shr, ExecutionState.STOP: stop, ExecutionState.RETURN: return_, } diff --git a/src/zkevm_specs/evm/execution/shl_shr.py b/src/zkevm_specs/evm/execution/shl_shr.py new file mode 100644 index 000000000..364f16066 --- /dev/null +++ b/src/zkevm_specs/evm/execution/shl_shr.py @@ -0,0 +1,119 @@ +from ...util import FQ, RLC +from ..instruction import Instruction, Transition +from ..opcode import Opcode + + +def shl_shr(instruction: Instruction): + opcode = instruction.opcode_lookup(True) + + pop1 = instruction.stack_pop() + pop2 = instruction.stack_pop() + push = instruction.stack_push() + + ( + is_shl, + shf0, + shift, + dividend, + divisor, + quotient, + remainder, + ) = gen_witness(opcode, pop1, pop2, push) + check_witness( + instruction, + is_shl, + shf0, + shift, + dividend, + divisor, + quotient, + remainder, + pop1, + pop2, + push, + ) + + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(3), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(1), + ) + + +def check_witness( + instruction: Instruction, + is_shl: FQ, + shf0: FQ, + shift: RLC, + dividend: RLC, + divisor: RLC, + quotient: RLC, + remainder: RLC, + pop1: RLC, + pop2: RLC, + push: RLC, +): + is_shr = 1 - is_shl + divisor_is_zero = instruction.word_is_zero(divisor) + + # Constrain stack pops and pushes as: + # - for SHL, two pops are shift and quotient, and push is dividend. + # - for SHR, two pops are shift and dividend, and push is quotient. + instruction.constrain_equal(pop1.expr(), shift.expr()) + instruction.constrain_equal( + pop2.expr(), + is_shl * quotient.expr() + is_shr * dividend.expr(), + ) + instruction.constrain_equal( + push.expr(), (is_shl * dividend.expr() + is_shr * quotient.expr()) * (1 - divisor_is_zero) + ) + + # Constrain shift == shift.cells[0] when divisor != 0. + instruction.constrain_zero( + (1 - divisor_is_zero) * (shift.expr() - shift.le_bytes[0]), + ) + + # Constrain remainder < divisor when divisor != 0. + remainder_lt_divisor, _ = instruction.compare_word(remainder, divisor) + instruction.constrain_zero((1 - divisor_is_zero) * (1 - remainder_lt_divisor)) + + # Constrain remainder == 0 for SHL. + remainder_is_zero = instruction.word_is_zero(remainder) + instruction.constrain_zero(is_shl * (1 - remainder_is_zero)) + + # Constrain overflow == 0 for SHR. + overflow = instruction.mul_add_words(quotient, divisor, remainder, dividend) + instruction.constrain_zero(is_shr * overflow) + + # Constrain divisor_lo == 2^shf0 when shf0 < 128, and + # divisor_hi == 2^(128 - shf0) otherwise. + divisor_lo = instruction.bytes_to_fq(divisor.le_bytes[:16]) + divisor_hi = instruction.bytes_to_fq(divisor.le_bytes[16:]) + if (1 - divisor_is_zero) == 1: + instruction.pow2_lookup(shf0, divisor_lo, divisor_hi) + + +def gen_witness(opcode: FQ, pop1: RLC, pop2: RLC, push: RLC): + is_shl = Opcode.SHR - opcode + shift = pop1 + shf0 = shift.le_bytes[0] + divisor = RLC(1 << shf0) if shf0 == shift.int_value else RLC(0) + if is_shl.n == 1: + dividend = push + quotient = pop2 + remainder = RLC(0) + else: # SHR + dividend = pop2 + quotient = push + remainder = RLC(dividend.int_value - quotient.int_value * divisor.int_value) + + return ( + is_shl, + shf0, + shift, + dividend, + divisor, + quotient, + remainder, + ) diff --git a/src/zkevm_specs/evm/execution/shr.py b/src/zkevm_specs/evm/execution/shr.py deleted file mode 100644 index 16804eede..000000000 --- a/src/zkevm_specs/evm/execution/shr.py +++ /dev/null @@ -1,142 +0,0 @@ -from ...util import FQ, N_BYTES_U64, RLC -from ..instruction import Instruction, Transition -from ..typing import Sequence - - -def shr(instruction: Instruction): - opcode = instruction.opcode_lookup(True) - - a = instruction.stack_pop() - shift = instruction.stack_pop() - b = instruction.stack_push() - - ( - a64s, - b64s, - a64s_lo, - a64s_hi, - shf_div64, - shf_mod64, - p_lo, - p_hi, - ) = gen_witness(instruction, a, shift) - check_witness( - instruction, - a, - shift, - b, - a64s, - b64s, - a64s_lo, - a64s_hi, - shf_div64, - shf_mod64, - p_lo, - p_hi, - ) - - instruction.step_state_transition_in_same_context( - opcode, - rw_counter=Transition.delta(2), - program_counter=Transition.delta(1), - stack_pointer=Transition.delta(1), - ) - - -def check_witness( - instruction: Instruction, - a: RLC, - shift: RLC, - b: RLC, - a64s: Sequence[FQ], - b64s: Sequence[FQ], - a64s_lo: Sequence[FQ], - a64s_hi: Sequence[FQ], - shf_div64, - shf_mod64, - p_lo, - p_hi, -): - shf_lt256 = instruction.is_zero(instruction.sum(shift.le_bytes[1:])) - for idx in range(4): - offset = idx * N_BYTES_U64 - - # a64s constraint - instruction.constrain_equal( - a64s[idx], - instruction.bytes_to_fq(a.le_bytes[offset : offset + N_BYTES_U64]), - ) - - # b64s constraint - instruction.constrain_equal( - b64s[idx] * shf_lt256, - instruction.bytes_to_fq(b.le_bytes[offset : offset + N_BYTES_U64]), - ) - - # `a64s[idx] == a64s_lo[idx] + a64s_hi[idx] * p_lo` - instruction.constrain_equal(a64s[idx], a64s_lo[idx] + a64s_hi[idx] * p_lo) - - # `a64s_lo[idx] < p_lo` - a64s_lo_lt_p_lo, _ = instruction.compare(a64s_lo[idx], p_lo, N_BYTES_U64) - instruction.constrain_equal(a64s_lo_lt_p_lo, FQ(1)) - - # merge contraints - shf_div64_eq0 = instruction.is_zero(shf_div64) - shf_div64_eq1 = instruction.is_zero(shf_div64 - 1) - shf_div64_eq2 = instruction.is_zero(shf_div64 - 2) - instruction.constrain_equal( - b64s[0], - (a64s_hi[0] + a64s_lo[1] * p_hi) * shf_div64_eq0 - + (a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq1 - + (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq2 - + a64s_hi[3] * (1 - shf_div64_eq0 - shf_div64_eq1 - shf_div64_eq2), - ) - instruction.constrain_equal( - b64s[1], - (a64s_hi[1] + a64s_lo[2] * p_hi) * shf_div64_eq0 - + (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq1 - + a64s_hi[3] * shf_div64_eq2, - ) - instruction.constrain_equal( - b64s[2], (a64s_hi[2] + a64s_lo[3] * p_hi) * shf_div64_eq0 + a64s_hi[3] * shf_div64_eq1 - ) - instruction.constrain_equal(b64s[3], a64s_hi[3] * shf_div64_eq0) - - # shift constraint - instruction.constrain_equal( - instruction.bytes_to_fq(shift.le_bytes[:1]), - shf_mod64 + shf_div64 * 64, - ) - - # `p_lo == pow(2, shf_mod64)` and `p_hi == pow(2, 64 - shf_mod64)`. - instruction.pow2_lookup(shf_mod64, p_lo) - instruction.pow2_lookup(64 - shf_mod64, p_hi) - - -def gen_witness(instruction: Instruction, a: RLC, shift: RLC): - shf0 = instruction.bytes_to_fq(shift.le_bytes[:1]) - shf_div64 = FQ(shf0.n // 64) - shf_mod64 = FQ(shf0.n % 64) - p_lo = FQ(1 << shf_mod64.n) - p_hi = FQ(1 << (64 - shf_mod64.n)) - - a64s = instruction.word_to_64s(a) - a64s_lo = [FQ(0)] * 4 - a64s_hi = [FQ(0)] * 4 - for idx in range(4): - a64s_lo[idx] = FQ(a64s[idx].n % p_lo.n) - a64s_hi[idx] = FQ(a64s[idx].n // p_lo.n) - - bb = a.int_value >> shf0.n - b64s = [FQ((bb >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)] - - return ( - a64s, - b64s, - a64s_lo, - a64s_hi, - shf_div64, - shf_mod64, - p_lo, - p_hi, - ) diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index b61d39600..b9583ab2a 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -35,8 +35,7 @@ class ExecutionState(IntEnum): BITWISE = auto() # AND, OR, XOR NOT = auto() BYTE = auto() - SHL = auto() - SHR = auto() + SHL_SHR = auto() SAR = auto() SHA3 = auto() ADDRESS = auto() @@ -177,10 +176,8 @@ def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: return [Opcode.NOT] elif self == ExecutionState.BYTE: return [Opcode.BYTE] - elif self == ExecutionState.SHL: - return [Opcode.SHL] - elif self == ExecutionState.SHR: - return [Opcode.SHR] + elif self == ExecutionState.SHL_SHR: + return [Opcode.SHL, Opcode.SHR] elif self == ExecutionState.SAR: return [Opcode.SAR] elif self == ExecutionState.SHA3: diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index de5d28035..3d274133f 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -1009,8 +1009,8 @@ def memory_copier_gas_cost( self.range_check(gas_cost, N_BYTES_GAS) return gas_cost - def pow2_lookup(self, value: Expression, value_pow: Expression): - self.fixed_lookup(FixedTableTag.Pow2, value, value_pow) + def pow2_lookup(self, value: Expression, pow_lo128: Expression, pow_hi128: Expression): + self.fixed_lookup(FixedTableTag.Pow2, value, pow_lo128, pow_hi128) def copy_lookup( self, diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 67a65be0e..bc8f52629 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -70,7 +70,15 @@ def table_assignments(self) -> List[FixedTableRow]: ) ] elif self == FixedTableTag.Pow2: - return [FixedTableRow(FQ(self), FQ(value), FQ(1 << value)) for value in range(65)] + return [ + FixedTableRow( + FQ(self), + FQ(value), + FQ(1 << value) if value < 128 else FQ(0), + FQ(0) if value < 128 else FQ(1 << (value - 128)), + ) + for value in range(256) + ] else: raise ValueError("Unreacheable") diff --git a/tests/evm/test_shl_shr.py b/tests/evm/test_shl_shr.py new file mode 100644 index 000000000..724626858 --- /dev/null +++ b/tests/evm/test_shl_shr.py @@ -0,0 +1,100 @@ +import pytest + +from zkevm_specs.evm import ( + Block, + Bytecode, + ExecutionState, + Opcode, + RWDictionary, + StepState, + Tables, + verify_steps, +) +from zkevm_specs.util import rand_fq, rand_word, RLC +from common import generate_nasty_tests + + +MAX_WORD = (1 << 256) - 1 + +TESTING_DATA = [ + (Opcode.SHL, 8, 0xABCD << 240), + (Opcode.SHL, 7, 0x1234 << 240), + (Opcode.SHL, 17, 0x8765 << 240), + (Opcode.SHL, 0, 0x4321 << 240), + (Opcode.SHL, 256, 0xFFFF), + (Opcode.SHL, 256 + 8 + 1, 0x12345), + (Opcode.SHL, 63, MAX_WORD), + (Opcode.SHL, 128, MAX_WORD), + (Opcode.SHL, 129, MAX_WORD), + (Opcode.SHR, 8, 0xABCD), + (Opcode.SHR, 7, 0x1234), + (Opcode.SHR, 17, 0x8765), + (Opcode.SHR, 0, 0x4321), + (Opcode.SHR, 256, 0xFFFF), + (Opcode.SHR, 256 + 8 + 1, 0x12345), + (Opcode.SHR, 63, (1 << 256) - 1), + (Opcode.SHR, 128, (1 << 256) - 1), + (Opcode.SHR, 129, (1 << 256) - 1), + (Opcode.SHL, rand_word(), rand_word()), + (Opcode.SHR, rand_word(), rand_word()), +] + +generate_nasty_tests(TESTING_DATA, (Opcode.SHL, Opcode.SHR)) + + +@pytest.mark.parametrize("opcode, shift, a", TESTING_DATA) +def test_shl_shr(opcode: Opcode, shift: int, a: int): + if opcode == Opcode.SHL: + b = a << shift & MAX_WORD if shift < 256 else 0 + bytecode = Bytecode().shl(shift, a) + else: # SHR + b = a >> shift if shift < 256 else 0 + bytecode = Bytecode().shr(shift, a) + + randomness = rand_fq() + shift = RLC(shift, randomness) + a = RLC(a, randomness) + b = RLC(b, randomness) + bytecode_hash = RLC(bytecode.hash(), randomness) + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(), + bytecode_table=set(bytecode.table_assignments(randomness)), + rw_table=set( + RWDictionary(9) + .stack_read(1, 1022, shift) + .stack_read(1, 1023, a) + .stack_write(1, 1023, b) + .rws + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.SHL_SHR, + rw_counter=9, + call_id=1, + is_root=True, + is_create=False, + code_hash=bytecode_hash, + program_counter=66, + stack_pointer=1022, + gas_left=3, + ), + StepState( + execution_state=ExecutionState.STOP, + rw_counter=12, + call_id=1, + is_root=True, + is_create=False, + code_hash=bytecode_hash, + program_counter=67, + stack_pointer=1023, + gas_left=0, + ), + ], + ) diff --git a/tests/evm/test_shr.py b/tests/evm/test_shr.py deleted file mode 100644 index 758032dca..000000000 --- a/tests/evm/test_shr.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest - -from zkevm_specs.evm import ( - ExecutionState, - StepState, - verify_steps, - Tables, - Block, - Bytecode, - RWDictionary, -) -from zkevm_specs.util import ( - rand_fq, - rand_range, - rand_word, - RLC, - U256, -) - - -TESTING_DATA = ( - (0xABCD, 8), - (0x1234, 7), - (0x8765, 17), - (0x4321, 0), - (0xFFFF, 256), - (0x12345, 256 + 8 + 1), - ((1 << 256) - 1, 63), - ((1 << 256) - 1, 128), - ((1 << 256) - 1, 129), -) - - -@pytest.mark.parametrize("value, shift", TESTING_DATA) -def test_shr(value: U256, shift: int): - result = value >> shift if shift <= 255 else 0 - - randomness = rand_fq() - value = RLC(value, randomness) - shift = RLC(shift, randomness) - result = RLC(result, randomness) - - bytecode = Bytecode().push32(value).push32(shift).shr().stop() - bytecode_hash = RLC(bytecode.hash(), randomness) - - tables = Tables( - block_table=set(Block().table_assignments(randomness)), - tx_table=set(), - bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - RWDictionary(9) - .stack_read(1, 1022, value) - .stack_read(1, 1023, shift) - .stack_write(1, 1023, result) - .rws - ), - ) - - verify_steps( - randomness=randomness, - tables=tables, - steps=[ - StepState( - execution_state=ExecutionState.SHR, - rw_counter=9, - call_id=1, - is_root=True, - is_create=False, - code_hash=bytecode_hash, - program_counter=66, - stack_pointer=1022, - gas_left=3, - ), - StepState( - execution_state=ExecutionState.STOP, - rw_counter=11, - call_id=1, - is_root=True, - is_create=False, - code_hash=bytecode_hash, - program_counter=67, - stack_pointer=1023, - gas_left=0, - ), - ], - )