From 099503a017ec902991564756718a4d11cb935472 Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 3 Dec 2021 22:37:43 +0800 Subject: [PATCH 1/8] refactor: use RLC directly and refactor all --- src/zkevm_specs/bytecode.py | 12 +- src/zkevm_specs/evm/execution/__init__.py | 17 +- src/zkevm_specs/evm/execution/add.py | 2 +- src/zkevm_specs/evm/execution/begin_tx.py | 63 +-- .../evm/execution/block_coinbase.py | 9 +- src/zkevm_specs/evm/execution/caller.py | 2 +- src/zkevm_specs/evm/execution/jump.py | 5 +- src/zkevm_specs/evm/execution/jumpi.py | 5 +- src/zkevm_specs/evm/execution/push.py | 8 +- src/zkevm_specs/evm/instruction.py | 405 +++++++++++++----- src/zkevm_specs/evm/main.py | 48 +-- src/zkevm_specs/evm/opcode.py | 30 +- src/zkevm_specs/evm/step.py | 14 +- src/zkevm_specs/evm/table.py | 105 +++-- src/zkevm_specs/evm/typing.py | 142 ++++-- src/zkevm_specs/util/__init__.py | 8 +- src/zkevm_specs/util/arithmetic.py | 101 ++--- src/zkevm_specs/util/hash.py | 10 +- src/zkevm_specs/util/param.py | 29 ++ tests/evm/test_add.py | 41 +- tests/evm/test_begin_tx.py | 57 ++- tests/evm/test_caller.py | 28 +- tests/evm/test_coinbase.py | 31 +- tests/evm/test_jump.py | 21 +- tests/evm/test_jumpi.py | 43 +- tests/evm/test_push.py | 36 +- tests/test_bytecode_circuit.py | 98 ++--- 27 files changed, 829 insertions(+), 541 deletions(-) diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 9d5f9537e..4d2b085a7 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -1,6 +1,6 @@ from typing import Sequence, Union, Tuple, Set from collections import namedtuple -from .util import keccak256, fp_add, fp_mul, RLCStore +from .util import keccak256, fp_add, fp_mul, RLC from .evm.opcode import get_push_size from .encoding import U8, U256, is_circuit_code @@ -87,7 +87,7 @@ def check_bytecode_row( # Populate the circuit matrix -def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_store: RLCStore): +def 
assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], randomness: int): # All rows are usable in this emulation last_row_offset = 2 ** k - 1 @@ -103,7 +103,7 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_s push_data_left = byte_push_size if is_code else push_data_left - 1 # Add the byte to the accumulator - hash_rlc = fp_add(fp_mul(hash_rlc, rlc_store.randomness), row[2]) + hash_rlc = fp_add(fp_mul(hash_rlc, randomness), row[2]) # Set the data for this row rows.append( @@ -162,10 +162,10 @@ def assign_push_table(): # Generate keccak table -def assign_keccak_table(bytecodes: Sequence[bytes], rlc_store: RLCStore): +def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int): keccak_table = [] for bytecode in bytecodes: - hash = rlc_store.to_rlc(keccak256(bytecode), 32) - rlc = rlc_store.to_rlc(list(reversed(bytecode))) + hash = RLC(bytes(reversed(keccak256(bytecode))), randomness) + rlc = RLC(bytes(reversed(bytecode)), randomness, len(bytecode)) keccak_table.append((rlc, len(bytecode), hash)) return keccak_table diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 0306dd1dd..26921e554 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -1,11 +1,24 @@ +from typing import Callable, Dict + +from ..execution_state import ExecutionState + from .begin_tx import * # Opcode's successful cases from .add import * -from .push import * from .jump import * from .jumpi import * +from .push import * from .block_coinbase import * from .caller import * -# Error cases + +EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { + ExecutionState.BeginTx: begin_tx, + ExecutionState.ADD: add, + ExecutionState.CALLER: caller, + ExecutionState.COINBASE: coinbase, + ExecutionState.JUMP: jump, + ExecutionState.JUMPI: jumpi, + ExecutionState.PUSH: push, +} diff --git a/src/zkevm_specs/evm/execution/add.py 
b/src/zkevm_specs/evm/execution/add.py index 9f040acef..8bd6f2e54 100644 --- a/src/zkevm_specs/evm/execution/add.py +++ b/src/zkevm_specs/evm/execution/add.py @@ -16,7 +16,7 @@ def add(instruction: Instruction): instruction.select(is_sub, a, c), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(3), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index 48dc2b29c..beeaa4ae3 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,41 +1,44 @@ +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX from ..instruction import Instruction, Transition -from ..table import CallContextFieldTag, TxContextFieldTag, RW, AccountFieldTag +from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag from ..precompiled import PrecompiledAddress -def begin_tx(instruction: Instruction, is_first_step: bool = False): - instruction.constrain_equal(instruction.curr.call_id, instruction.curr.rw_counter) +def begin_tx(instruction: Instruction): + call_id = instruction.curr.rw_counter - tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) - rw_counter_end_of_reversion = instruction.call_context_lookup(CallContextFieldTag.RWCounterEndOfReversion) - is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent) + tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=call_id) + rw_counter_end_of_reversion = instruction.call_context_lookup( + CallContextFieldTag.RwCounterEndOfReversion, call_id=call_id + ) + is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent, call_id=call_id) - if is_first_step: + if instruction.is_first_step: instruction.constrain_equal(instruction.curr.rw_counter, 1) instruction.constrain_equal(tx_id, 1) - tx_caller_address = 
instruction.tx_lookup(tx_id, TxContextFieldTag.CallerAddress) - tx_callee_address = instruction.tx_lookup(tx_id, TxContextFieldTag.CalleeAddress) - tx_is_create = instruction.tx_lookup(tx_id, TxContextFieldTag.IsCreate) - tx_value = instruction.tx_lookup(tx_id, TxContextFieldTag.Value) - tx_call_data_length = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataLength) + tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) + tx_callee_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CalleeAddress) + tx_is_create = instruction.tx_context_lookup(tx_id, TxContextFieldTag.IsCreate) + tx_value = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Value) + tx_call_data_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength) # Verify nonce - tx_nonce = instruction.tx_lookup(tx_id, TxContextFieldTag.Nonce) + tx_nonce = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Nonce) nonce, nonce_prev = instruction.account_write(tx_caller_address, AccountFieldTag.Nonce) instruction.constrain_equal(tx_nonce, nonce_prev) instruction.constrain_equal(nonce, nonce_prev + 1) # TODO: Implement EIP 1559 (currently it supports legacy transaction format) # Calculate gas fee - tx_gas = instruction.tx_lookup(tx_id, TxContextFieldTag.Gas) - tx_gas_price = instruction.tx_lookup(tx_id, TxContextFieldTag.GasPrice) + tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas) + tx_gas_price = instruction.tx_gas_price(tx_id) gas_fee, carry = instruction.mul_word_by_u64(tx_gas_price, tx_gas) instruction.constrain_zero(carry) # TODO: Handle gas cost of tx level access list (EIP 2930) - tx_call_data_gas_cost = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataGasCost) - gas_left = tx_gas - (53000 if tx_is_create else 21000) - tx_call_data_gas_cost + tx_call_data_gas_cost = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataGasCost) + gas_left = tx_gas - (GAS_COST_CREATION_TX if tx_is_create 
else GAS_COST_TX) - tx_call_data_gas_cost instruction.constrain_gas_left_not_underflow(gas_left) # Prepare access list of caller and callee @@ -47,27 +50,27 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False): tx_caller_address, tx_callee_address, tx_value, - gas_fee=gas_fee, - is_persistent=is_persistent, - rw_counter_end_of_reversion=rw_counter_end_of_reversion, + gas_fee, + is_persistent, + rw_counter_end_of_reversion, ) if tx_is_create: # TODO: Verify receiver address - # TODO: Set opcode_source to tx_id + # TODO: Decide what code_source should be (tx_id or hash of creation code) raise NotImplementedError elif tx_callee_address in list(PrecompiledAddress): # TODO: Handle precompile raise NotImplementedError else: - code_hash, _ = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash) + code_hash = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash) # Setup next call's context # Note that: - # - CallerCallId, ReturnDataOffset, ReturnDataLength, Result - # should never be used in root call, so unnecessary to check - # - TxId is propagated from previous step or constraint to 1 if is_first_step - # - IsPersistent will be verified in the end of tx + # - CallerId, ReturnDataOffset, ReturnDataLength + # should never be used in root call, so unnecessary to be checked + # - TxId is checked from previous step or constraint to 1 if is_first_step + # - IsSuccess, IsPersistent will be verified in the end of tx for (tag, value) in [ (CallContextFieldTag.Depth, 1), (CallContextFieldTag.CallerAddress, tx_caller_address), @@ -77,14 +80,14 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False): (CallContextFieldTag.Value, tx_value), (CallContextFieldTag.IsStatic, False), ]: - instruction.constrain_equal(instruction.call_context_lookup(tag), value) + instruction.constrain_equal(instruction.call_context_lookup(tag, call_id=call_id), value) - instruction.constrain_new_context_state_transition( + 
instruction.step_state_transition_to_new_context( rw_counter=Transition.delta(16), - call_id=Transition.persistent(), + call_id=Transition.to(call_id), is_root=Transition.to(True), is_create=Transition.to(False), - opcode_source=Transition.to(code_hash), + code_source=Transition.to(code_hash), gas_left=Transition.to(gas_left), state_write_counter=Transition.to(2), ) diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 3b47cdc28..18c537f84 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -11,10 +11,15 @@ def coinbase(instruction: Instruction): # check block table for coinbase address instruction.constrain_equal( address, - instruction.bytes_to_rlc(instruction.int_to_bytes(instruction.block_lookup(BlockContextFieldTag.Coinbase), 20)), + instruction.bytes_to_rlc( + instruction.int_to_bytes( + instruction.block_context_lookup(BlockContextFieldTag.Coinbase), + 20, + ) + ), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 1750fb905..1d0011394 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -19,7 +19,7 @@ def caller(instruction: Instruction): ), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(2), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/jump.py b/src/zkevm_specs/evm/execution/jump.py index 454feb6c9..c21d668a3 100644 --- a/src/zkevm_specs/evm/execution/jump.py +++ b/src/zkevm_specs/evm/execution/jump.py @@ -1,3 +1,4 @@ +from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition 
from ..opcode import Opcode @@ -11,13 +12,13 @@ def jump(instruction: Instruction): dest = instruction.stack_pop() # Get `dest` raw value in max 8 bytes - dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8)) + dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER) # Verify `dest` is code within byte code table # assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True) instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True)) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.to(dest_value), diff --git a/src/zkevm_specs/evm/execution/jumpi.py b/src/zkevm_specs/evm/execution/jumpi.py index 3cba2b651..6ff768c5c 100644 --- a/src/zkevm_specs/evm/execution/jumpi.py +++ b/src/zkevm_specs/evm/execution/jumpi.py @@ -1,3 +1,4 @@ +from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -16,12 +17,12 @@ def jumpi(instruction: Instruction): pc_diff = 1 else: # Get `dest` raw value in max 8 bytes - dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8)) + dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER) pc_diff = dest_value - instruction.curr.program_counter # assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True) instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True)) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(2), program_counter=Transition.delta(pc_diff), diff --git a/src/zkevm_specs/evm/execution/push.py b/src/zkevm_specs/evm/execution/push.py index de41bbb8e..3f01f6265 100644 --- a/src/zkevm_specs/evm/execution/push.py +++ b/src/zkevm_specs/evm/execution/push.py @@ -8,17 +8,17 @@ def 
push(instruction: Instruction): num_additional_pushed = num_pushed - 1 value = instruction.stack_push() - value_bytes = instruction.rlc_to_bytes(value, 32) + value_le_bytes = instruction.rlc_to_le_bytes(value) selectors = instruction.continuous_selectors(num_additional_pushed, 31) for idx in range(32): index = instruction.curr.program_counter + num_pushed - idx if idx == 0 or selectors[idx - 1]: - instruction.constrain_equal(value_bytes[idx], instruction.opcode_lookup_at(index, False)) + instruction.constrain_equal(value_le_bytes[idx], instruction.opcode_lookup_at(index, False)) else: - instruction.constrain_zero(value_bytes[idx]) + instruction.constrain_zero(value_le_bytes[idx]) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.delta(1 + num_pushed), diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index c8bf630e5..13cb8aa2b 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -2,7 +2,17 @@ from enum import IntEnum, auto from typing import Optional, Sequence, Tuple, Union -from ..util import Array4, Array8, linear_combine, RLCStore, MAX_N_BYTES, N_BYTES_GAS +from ..util import ( + Array4, + Array8, + RLC, + MAX_N_BYTES, + N_BYTES_MEMORY_ADDRESS, + N_BYTES_MEMORY_SIZE, + N_BYTES_GAS, + MEMORY_EXPANSION_QUAD_DENOMINATOR, + MEMORY_EXPANSION_LINEAR_COEFF, +) from .opcode import Opcode from .step import StepState from .table import ( @@ -23,7 +33,7 @@ def __init__(self, message: str) -> None: class TransitionKind(IntEnum): - Persistent = auto() + Same = auto() Delta = auto() To = auto() @@ -36,8 +46,8 @@ def __init__(self, kind: TransitionKind, value: Optional[int] = None) -> None: self.kind = kind self.value = value - def persistent() -> Transition: - return Transition(TransitionKind.Persistent) + def same() -> Transition: + return Transition(TransitionKind.Same) def 
delta(delta: int): return Transition(TransitionKind.Delta, delta) @@ -47,85 +57,104 @@ def to(to: int): class Instruction: - rlc_store: RLCStore + randomness: int tables: Tables curr: StepState next: StepState + # meta information + is_first_step: bool + is_last_step: bool + # helper numbers - cell_offset: int = 0 rw_counter_offset: int = 0 program_counter_offset: int = 0 stack_pointer_offset: int = 0 state_write_counter_offset: int = 0 - def __init__(self, rlc_store: RLCStore, tables: Tables, curr: StepState, next: StepState) -> None: - self.rlc_store = rlc_store + def __init__( + self, + randomness: int, + tables: Tables, + curr: StepState, + next: StepState, + is_first_step: bool, + is_last_step: bool, + ) -> None: + self.randomness = randomness self.tables = tables self.curr = curr self.next = next + self.is_first_step = is_first_step + self.is_last_step = is_last_step def constrain_zero(self, value: int): - assert value == 0 + assert value == 0, ConstraintUnsatFailure(f"Expected value to be 0, but got {value}") def constrain_equal(self, lhs: int, rhs: int): - self.constrain_zero(lhs - rhs) + assert lhs == rhs, ConstraintUnsatFailure(f"Expected values to be equal, but got {lhs} and {rhs}") def constrain_bool(self, value: int): - assert value in [0, 1] + assert value in [0, 1], ConstraintUnsatFailure(f"Expected value to be a bool, but got {value}") def constrain_gas_left_not_underflow(self, gas_left: int): self.int_to_bytes(gas_left, N_BYTES_GAS) - def constrain_state_transition(self, **kwargs: Transition): - for key in [ - "rw_counter", - "call_id", - "is_root", - "is_create", - "opcode_source", - "program_counter", - "stack_pointer", - "gas_left", - "memory_size", - "state_write_counter", - "last_callee_id", - "last_callee_return_data_offset", - "last_callee_return_data_length", - ]: + def constrain_step_state_transition(self, **kwargs: Transition): + keys = set( + [ + "rw_counter", + "call_id", + "is_root", + "is_create", + "code_source", + "program_counter", 
+ "stack_pointer", + "gas_left", + "memory_size", + "state_write_counter", + "last_callee_id", + "last_callee_return_data_offset", + "last_callee_return_data_length", + ] + ) + + assert keys.issuperset( + kwargs.keys() + ), f"Invalid keys {list(set(kwargs.keys()).difference(keys))} for step state transition" + + for key in keys: curr, next = getattr(self.curr, key), getattr(self.next, key) - transition = kwargs.get(key, Transition.persistent()) - if transition.kind == TransitionKind.Persistent: - assert next == curr, ConstraintUnsatFailure( - f"state {key} should be persistent as {curr}, but got {next}" - ) + transition = kwargs.get(key, Transition.same()) + if transition.kind == TransitionKind.Same: + assert next == curr, ConstraintUnsatFailure(f"State {key} should be same as {curr}, but got {next}") elif transition.kind == TransitionKind.Delta: assert next == curr + transition.value, ConstraintUnsatFailure( - f"state {key} should transit to ${curr + transition.value}, but got {next}" + f"State {key} should transit to {curr + transition.value}, but got {next}" ) elif transition.kind == TransitionKind.To: assert next == transition.value, ConstraintUnsatFailure( - f"state {key} should transit to ${transition.value}, but got {next}" + f"State {key} should transit to {transition.value}, but got {next}" ) else: raise ValueError("unreacheable") - def constrain_new_context_state_transition( + def step_state_transition_to_new_context( self, rw_counter: Transition, call_id: Transition, is_root: Transition, is_create: Transition, - opcode_source: Transition, + code_source: Transition, gas_left: Transition, state_write_counter: Transition, ): - self.constrain_state_transition( + self.constrain_step_state_transition( rw_counter=rw_counter, call_id=call_id, is_root=is_root, is_create=is_create, - opcode_source=opcode_source, + code_source=code_source, gas_left=gas_left, state_write_counter=state_write_counter, # Initailization unconditionally @@ -137,19 +166,21 @@ def 
constrain_new_context_state_transition( last_callee_return_data_length=Transition.to(0), ) - def constrain_same_context_state_transition( + def step_state_transition_in_same_context( self, opcode: int, - rw_counter: Transition = Transition.persistent(), - program_counter: Transition = Transition.persistent(), - stack_pointer: Transition = Transition.persistent(), - memory_size: Transition = Transition.persistent(), + rw_counter: Transition = Transition.same(), + program_counter: Transition = Transition.same(), + stack_pointer: Transition = Transition.same(), + memory_size: Transition = Transition.same(), dynamic_gas_cost: int = 0, ): - gas_cost = Opcode(opcode).constant_gas_cost() + dynamic_gas_cost + self.responsible_opcode_lookup(opcode) + gas_cost = Opcode(opcode).constant_gas_cost() + dynamic_gas_cost self.constrain_gas_left_not_underflow(self.curr.gas_left - gas_cost) - self.constrain_state_transition( + + self.constrain_step_state_transition( rw_counter=rw_counter, program_counter=program_counter, stack_pointer=stack_pointer, @@ -157,10 +188,19 @@ def constrain_same_context_state_transition( gas_left=Transition.delta(-gas_cost), ) - def is_zero(self, value: int) -> bool: + def sum(self, values: Sequence[int]) -> int: + return sum(values) + + def is_zero(self, value: Union[int, RLC]) -> bool: + if isinstance(value, RLC): + value = value.value return value == 0 - def is_equal(self, lhs: int, rhs: int) -> bool: + def is_equal(self, lhs: Union[int, RLC], rhs: Union[int, RLC]) -> bool: + if isinstance(lhs, RLC): + lhs = lhs.value + if isinstance(rhs, RLC): + rhs = rhs.value return self.is_zero(lhs - rhs) def continuous_selectors(self, t: int, n: int) -> Sequence[int]: @@ -172,22 +212,38 @@ def select(self, condition: bool, when_true: int, when_false: int) -> int: def pair_select(self, value: int, lhs: int, rhs: int) -> Tuple[bool, bool]: return value == lhs, value == rhs - def add_words(self, addends: Sequence[int]) -> Tuple[int, int]: - def rlc_to_lo_hi(rlc: int) 
-> Tuple[Sequence[int], Sequence[int]]: - bytes = self.rlc_to_bytes(rlc, 32) - return self.bytes_to_int(bytes[:16]), self.bytes_to_int(bytes[16:]) + def constant_divmod(self, numerator: int, denominator: int, n_bytes: int) -> Tuple[int, int]: + quotient, remainder = divmod(numerator, denominator) + self.int_to_bytes(quotient, n_bytes) + return quotient, remainder - addends_lo, addends_hi = list(zip(*map(rlc_to_lo_hi, addends))) - carry_lo, sum_lo = divmod(sum(addends_lo), 1 << 128) - carry_hi, sum_hi = divmod(sum(addends_hi) + carry_lo, 1 << 128) + def compare(self, lhs: int, rhs: int, n_bytes: int) -> Tuple[bool, bool]: + assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + + return lhs < rhs, lhs == rhs + + def min(self, lhs: int, rhs: int, n_bytes: int) -> int: + lt, _ = self.compare(lhs, rhs, n_bytes) + return self.select(lt, lhs, rhs) + + def max(self, lhs: int, rhs: int, n_bytes: int) -> int: + lt, _ = self.compare(lhs, rhs, n_bytes) + return self.select(lt, rhs, lhs) + + def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, int]: + addends_lo, addends_hi = list(zip(*map(self.word_to_lo_hi, addends))) + + carry_lo, sum_lo = divmod(self.sum(addends_lo), 1 << 128) + carry_hi, sum_hi = divmod(self.sum(addends_hi) + carry_lo, 1 << 128) sum_bytes = sum_lo.to_bytes(16, "little") + sum_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(sum_bytes), carry_hi + return RLC(sum_bytes, self.randomness), carry_hi + + def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: + minuend_lo, minuend_hi = self.word_to_lo_hi(minuend) + subtrahend_lo, subtrahend_hi = self.word_to_lo_hi(subtrahend) - def sub_word(self, minuend: int, subtrahend: int) -> Tuple[int, bool]: - minuend_lo, minuend_hi = self.rlc_to_lo_hi(minuend) - subtrahend_lo, subtrahend_hi = self.rlc_to_lo_hi(subtrahend) borrow_lo = minuend_lo < subtrahend_lo diff_lo = minuend_lo - subtrahend_lo + (1 << 128 if borrow_lo else 0) borrow_hi = minuend_hi < 
subtrahend_hi + borrow_lo @@ -195,56 +251,81 @@ def sub_word(self, minuend: int, subtrahend: int) -> Tuple[int, bool]: diff_bytes = diff_lo.to_bytes(16, "little") + diff_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(diff_bytes), borrow_hi - - def mul_word_by_u64(self, multiplicand: int, multiplier: int) -> Tuple[int, int]: - multiplicand_bytes = self.rlc_to_bytes(multiplicand, 32) + return RLC(diff_bytes, self.randomness), borrow_hi - multiplicand_lo = self.bytes_to_int(multiplicand_bytes[:16]) - multiplicand_hi = self.bytes_to_int(multiplicand_bytes[16:]) + def mul_word_by_u64(self, multiplicand: RLC, multiplier: int) -> Tuple[RLC, int]: + multiplicand_lo, multiplicand_hi = self.word_to_lo_hi(multiplicand) quotient_lo, product_lo = divmod(multiplicand_lo * multiplier, 1 << 128) quotient_hi, product_hi = divmod(multiplicand_hi * multiplier + quotient_lo, 1 << 128) product_bytes = product_lo.to_bytes(16, "little") + product_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(product_bytes), quotient_hi + return RLC(product_bytes, self.randomness), quotient_hi + + def rlc_to_le_bytes(self, rlc: RLC) -> Sequence[int]: + return rlc.le_bytes + + def rlc_to_int_unchecked(self, rlc: RLC, n_bytes: int) -> int: + rlc_le_bytes = self.rlc_to_le_bytes(rlc) + return self.bytes_to_int(rlc_le_bytes[:n_bytes]), self.is_zero(self.sum(rlc_le_bytes[n_bytes:])) + + def rlc_to_int_exact(self, rlc: RLC, n_bytes: int) -> int: + rlc_le_bytes = self.rlc_to_le_bytes(rlc) - def rlc_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: - bytes = self.rlc_store.to_bytes(value) - if len(bytes) > n_bytes and any(bytes[n_bytes:]): - raise ConstraintUnsatFailure(f"{value} is too many bytes to fit {n_bytes} bytes") - return list(bytes) + (n_bytes - len(bytes)) * [0] + if sum(rlc_le_bytes[n_bytes:]) > 0: + raise ConstraintUnsatFailure(f"Value {rlc} has too many bytes to fit {n_bytes} bytes") - def bytes_to_rlc(self, bytes: Sequence[int]) -> int: - return 
self.rlc_store.to_rlc(bytes) + return self.bytes_to_int(rlc_le_bytes[:n_bytes]) + + def word_to_lo_hi(self, word: RLC) -> Tuple[Sequence[int], Sequence[int]]: + word_le_bytes = self.rlc_to_le_bytes(word) + assert len(word_le_bytes) == 32, "Expected RLC to contain 32 bytes" + return self.bytes_to_int(word_le_bytes[:16]), self.bytes_to_int(word_le_bytes[16:]) + + def bytes_to_rlc(self, bytes: Sequence[int]) -> RLC: + return RLC(bytes, self.randomness) def bytes_to_int(self, bytes: Sequence[int]) -> int: - assert len(bytes) <= MAX_N_BYTES, "too many bytes to composite an integer in field" - return linear_combine(bytes, 256) + assert len(bytes) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + + return int.from_bytes(bytes, "little") def int_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: - assert n_bytes <= MAX_N_BYTES, "too many bytes to composite an integer in field" + assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + try: return value.to_bytes(n_bytes, "little") except OverflowError: - raise ConstraintUnsatFailure(f"{value} is too many bytes to fit {n_bytes} bytes") + raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") + + def range_lookup(self, input: int, range: int): + self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), input, 0, 0]) def byte_range_lookup(self, input: int): - self.tables.fixed_lookup([FixedTableTag.Range256, input, 0, 0]) + self.range_lookup(input, 256) def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[int]) -> Array4: return self.tables.fixed_lookup([tag] + inputs) - def block_lookup(self, tag: BlockContextFieldTag, index: int = 0) -> int: + def block_context_lookup(self, tag: BlockContextFieldTag, index: int = 0) -> int: return self.tables.block_lookup([tag, index])[2] - def tx_lookup(self, tx_id: int, tag: TxContextFieldTag, index: int = 0) -> int: - return self.tables.tx_lookup([tx_id, tag, index])[3] + def 
tx_context_lookup(self, tx_id: int, field_tag: TxContextFieldTag) -> Union[int, RLC]: + return self.tables.tx_lookup([tx_id, field_tag, 0])[3] + + def tx_calldata_lookup(self, tx_id: int, index: int) -> int: + return self.tables.tx_lookup([tx_id, TxContextFieldTag.CallData, index])[3] def bytecode_lookup(self, bytecode_hash: int, index: int, is_code: int) -> int: return self.tables.bytecode_lookup([bytecode_hash, index, Tables._, is_code])[2] + def tx_gas_price(self, tx_id: int) -> int: + return self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice) + + def responsible_opcode_lookup(self, opcode: int): + self.fixed_lookup(FixedTableTag.ResponsibleOpcode, [self.curr.execution_state, opcode]) + def opcode_lookup(self, is_code: bool) -> int: index = self.curr.program_counter + self.program_counter_offset self.program_counter_offset += 1 @@ -257,7 +338,7 @@ def opcode_lookup_at(self, index: int, is_code: bool) -> int: "The opcode source when is_root and is_create (root creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.opcode_source, index, is_code) + return self.bytecode_lookup(self.curr.code_source, index, is_code) def rw_lookup(self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None) -> Array8: if rw_counter is None: @@ -285,48 +366,67 @@ def state_write_with_reversion( inputs: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> Array8: assert tag.write_with_reversion() row = self.rw_lookup(RW.Write, tag, inputs) - rw_counter = rw_counter_end_of_reversion - self.curr.state_write_counter - self.state_write_counter_offset - self.state_write_counter_offset += 1 + if state_write_counter is None: + state_write_counter = self.curr.state_write_counter + self.state_write_counter_offset + self.state_write_counter_offset += 1 + + rw_counter = rw_counter_end_of_reversion - state_write_counter if not is_persistent: # Swap value and value_prev 
inputs = list(row[3:]) if tag == RWTableTag.TxAccessListAccount: inputs[2], inputs[3] = inputs[3], inputs[2] - elif tag == RWTableTag.TxAccessListStorageSlot: + elif tag == RWTableTag.TxAccessListAccountStorage: inputs[3], inputs[4] = inputs[4], inputs[3] elif tag == RWTableTag.Account: inputs[2], inputs[3] = inputs[3], inputs[2] elif tag == RWTableTag.AccountStorage: - inputs[3], inputs[4] = inputs[4], inputs[3] + inputs[2], inputs[3] = inputs[3], inputs[2] self.rw_lookup(RW.Write, tag, inputs, rw_counter=rw_counter) return row - def call_context_lookup(self, tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Union[int, None] = None) -> int: + def call_context_lookup( + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[int] = None + ) -> int: if call_id is None: call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, tag])[5] + return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag])[5] - def stack_pop(self) -> int: + def stack_pop(self) -> Union[int, RLC]: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 return self.stack_lookup(False, stack_pointer_offset) - def stack_push(self) -> int: + def stack_push(self) -> Union[int, RLC]: self.stack_pointer_offset -= 1 return self.stack_lookup(True, self.stack_pointer_offset) - def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> int: + def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[int, RLC]: stack_pointer = self.curr.stack_pointer + stack_pointer_offset return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer])[5] + def memory_write(self, memory_address: int, call_id: Optional[int] = None) -> int: + return self.memory_lookup(RW.Write, memory_address, call_id) + + def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = None) -> int: + if call_id is None: + call_id = self.curr.call_id + + return self.rw_lookup(rw, RWTableTag.Memory, [call_id, 
memory_address])[5] + + def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> int: + row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) + return row[5] + def account_write( self, account_address: int, @@ -345,20 +445,23 @@ def account_write_with_reversion( account_field_tag: AccountFieldTag, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> Tuple[int, int]: row = self.state_write_with_reversion( RWTableTag.Account, [account_address, account_field_tag], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[5], row[6] - def add_balance(self, account_address: int, values: Sequence[int]): + def add_balance(self, account_address: int, values: Sequence[int]) -> Tuple[int, int]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result) self.constrain_zero(carry) + return balance, balance_prev def add_balance_with_reversion( self, @@ -366,19 +469,22 @@ def add_balance_with_reversion( values: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - ): + state_write_counter: Optional[int] = None, + ) -> Tuple[int, int]: balance, balance_prev = self.account_write_with_reversion( - account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion + account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion, state_write_counter ) result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result) self.constrain_zero(carry) + return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[int]): + def sub_balance(self, account_address: int, values: Sequence[int]) -> Tuple[int, int]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = 
self.add_words([balance, *values]) self.constrain_equal(balance_prev, result) self.constrain_zero(carry) + return balance, balance_prev def sub_balance_with_reversion( self, @@ -386,17 +492,15 @@ def sub_balance_with_reversion( values: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - ): + state_write_counter: Optional[int] = None, + ) -> Tuple[int, int]: balance, balance_prev = self.account_write_with_reversion( - account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion + account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion, state_write_counter ) result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result) self.constrain_zero(carry) - - def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> Tuple[int, int]: - row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) - return row[5], row[6] + return balance, balance_prev def add_account_to_access_list( self, @@ -416,40 +520,45 @@ def add_account_to_access_list_with_reversion( account_address: int, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> bool: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccount, [tx_id, account_address, 1], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[5] - row[6] - def add_storage_slot_to_access_list( + def add_account_storage_to_access_list( self, tx_id: int, account_address: int, - storage_slot: int, + storage_key: int, ) -> bool: - row = self.state_write_with_reversion( - RWTableTag.TxAccessListAccount, - [tx_id, account_address, storage_slot, 1], + row = self.rw_lookup( + RW.Write, + RWTableTag.TxAccessListAccountStorage, + [tx_id, account_address, storage_key, 1], ) return row[6] - row[7] - def add_storage_slot_to_access_list_with_reversion( + def add_account_storage_to_access_list_with_reversion( self, 
tx_id: int, account_address: int, - storage_slot: int, + storage_key: int, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> bool: row = self.state_write_with_reversion( - RWTableTag.TxAccessListAccount, - [tx_id, account_address, storage_slot, 1], + RWTableTag.TxAccessListAccountStorage, + [tx_id, account_address, storage_key, 1], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[6] - row[7] @@ -461,16 +570,88 @@ def transfer_with_gas_fee( gas_fee: int, is_persistent: bool, rw_counter_end_of_reversion: int, - ): - self.sub_balance_with_reversion( + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + sender_balance_pair = self.sub_balance_with_reversion( sender_address, [value, gas_fee], is_persistent, rw_counter_end_of_reversion, ) - self.add_balance_with_reversion( + receiver_balance_pair = self.add_balance_with_reversion( receiver_address, [value], is_persistent, rw_counter_end_of_reversion, ) + return sender_balance_pair, receiver_balance_pair + + def transfer( + self, + sender_address: int, + receiver_address: int, + value: int, + is_persistent: bool, + rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + sender_balance_pair = self.sub_balance_with_reversion( + sender_address, + [value], + is_persistent, + rw_counter_end_of_reversion, + state_write_counter, + ) + receiver_balance_pair = self.add_balance_with_reversion( + receiver_address, + [value], + is_persistent, + rw_counter_end_of_reversion, + None if state_write_counter is None else state_write_counter + 1, + ) + return sender_balance_pair, receiver_balance_pair + + def memory_offset_and_length_to_int(self, offset: RLC, length: RLC) -> Tuple[int, int]: + length = self.rlc_to_int_exact(length, N_BYTES_MEMORY_ADDRESS) + if self.is_zero(length): + return 0, 0 + + offset = self.rlc_to_int_exact(offset, N_BYTES_MEMORY_ADDRESS) + return offset, length + + 
def memory_gas_cost(self, memory_size: int) -> int: + quadratic_cost, _ = self.constant_divmod( + memory_size * memory_size, MEMORY_EXPANSION_QUAD_DENOMINATOR, N_BYTES_GAS + ) + linear_cost = MEMORY_EXPANSION_LINEAR_COEFF * memory_size + return quadratic_cost + linear_cost + + def memory_expansion_constant_length(self, offset: int, length: int) -> Tuple[int, int]: + memory_size, _ = self.constant_divmod(length + offset + 31, 32, N_BYTES_MEMORY_SIZE) + + next_memory_size = self.max(self.curr.memory_size, memory_size, N_BYTES_MEMORY_SIZE) + + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) + memory_gas_cost_next = self.memory_gas_cost(next_memory_size) + memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost + + return next_memory_size, memory_expansion_gas_cost + + def memory_expansion_dynamic_length( + self, + cd_offset: int, + cd_length: int, + rd_offset: Optional[int] = None, + rd_length: Optional[int] = None, + ) -> Tuple[int, int]: + cd_memory_size, _ = self.constant_divmod(cd_offset + cd_length + 31, 32, N_BYTES_MEMORY_SIZE) + next_memory_size = self.max(self.curr.memory_size, cd_memory_size, N_BYTES_MEMORY_SIZE) + + if rd_offset is not None: + rd_memory_size, _ = self.constant_divmod(rd_offset + rd_length + 31, 32, N_BYTES_MEMORY_SIZE) + next_memory_size = self.max(next_memory_size, rd_memory_size, N_BYTES_MEMORY_SIZE) + + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) + memory_gas_cost_next = self.memory_gas_cost(next_memory_size) + memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost + + return next_memory_size, memory_expansion_gas_cost diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index 0564289dc..50dcff42e 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -1,15 +1,6 @@ from typing import Sequence -from ..util.arithmetic import RLCStore -from .execution import ( - add, - begin_tx, - push, - jump, - jumpi, - coinbase, - caller, -) +from .execution 
import EXECUTION_STATE_IMPL from .execution_state import ExecutionState from .instruction import Instruction from .step import StepState @@ -17,7 +8,7 @@ def verify_steps( - rlc_store: RLCStore, + randomness: int, tables: Tables, steps: Sequence[StepState], begin_with_first_step: bool = False, @@ -25,40 +16,29 @@ def verify_steps( ): for idx in range(len(steps) - 1): verify_step( - Instruction(rlc_store=rlc_store, tables=tables, curr=steps[idx], next=steps[idx + 1]), - begin_with_first_step and idx == 0, - end_with_final_step and idx == len(steps) - 2, + Instruction( + randomness=randomness, + tables=tables, + curr=steps[idx], + next=steps[idx + 1], + is_first_step=begin_with_first_step and idx == 0, + is_last_step=end_with_final_step and idx == len(steps) - 2, + ), ) def verify_step( instruction: Instruction, - is_first_step: bool = False, - is_final_step: bool = False, ): - if is_first_step: + if instruction.is_first_step: instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.BeginTx) - if instruction.curr.execution_state == ExecutionState.BeginTx: - begin_tx(instruction, is_first_step) - # Opcode's successful cases - elif instruction.curr.execution_state == ExecutionState.ADD: - add(instruction) - elif instruction.curr.execution_state == ExecutionState.PUSH: - push(instruction) - elif instruction.curr.execution_state == ExecutionState.JUMP: - jump(instruction) - elif instruction.curr.execution_state == ExecutionState.JUMPI: - jumpi(instruction) - elif instruction.curr.execution_state == ExecutionState.COINBASE: - coinbase(instruction) - elif instruction.curr.execution_state == ExecutionState.CALLER: - caller(instruction) - # Error cases + if instruction.curr.execution_state in EXECUTION_STATE_IMPL: + EXECUTION_STATE_IMPL[instruction.curr.execution_state](instruction) else: raise NotImplementedError - if is_final_step: + if instruction.is_last_step: # Verify no malicious insertion assert instruction.curr.rw_counter == 
len(instruction.tables.rw_table) diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 362126dac..21c24aedd 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -149,6 +149,24 @@ class Opcode(IntEnum): def hex(self) -> str: return "{:02x}".format(self) + def bytes(self) -> bytes: + return bytes([self]) + + def is_push(self) -> bool: + return Opcode.PUSH1 <= self <= Opcode.PUSH32 + + def is_dup(self) -> bool: + return Opcode.DUP1 <= self <= Opcode.DUP16 + + def is_swap(self) -> bool: + return Opcode.SWAP1 <= self <= Opcode.SWAP16 + + def max_stack_pointer(self) -> int: + return OPCODE_INFO_MAP[self].max_stack_pointer + + def min_stack_pointer(self) -> int: + return OPCODE_INFO_MAP[self].min_stack_pointer + def constant_gas_cost(self) -> int: return OPCODE_INFO_MAP[self].constant_gas_cost @@ -339,15 +357,14 @@ def valid_opcodes() -> Sequence[Opcode]: def invalid_opcodes() -> Sequence[int]: - return [opcode for opcode in range(256) if opcode not in OPCODE_INFO_MAP] + return [opcode for opcode in range(256) if opcode not in valid_opcodes()] def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: pairs = [] for opcode in valid_opcodes(): - opcode_info = OPCODE_INFO_MAP[opcode] - if opcode_info.min_stack_pointer > 0: - for stack_pointer in range(opcode_info.min_stack_pointer): + if opcode.min_stack_pointer() > 0: + for stack_pointer in range(opcode.min_stack_pointer()): pairs.append((opcode, stack_pointer)) return pairs @@ -355,9 +372,8 @@ def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: def stack_underflow_pairs() -> Sequence[Tuple[int, int]]: pairs = [] for opcode in valid_opcodes(): - opcode_info = OPCODE_INFO_MAP[opcode] - if opcode_info.max_stack_pointer < 1024: - for stack_pointer in range(opcode_info.max_stack_pointer, 1024): + if opcode.max_stack_pointer() < 1024: + for stack_pointer in range(opcode.max_stack_pointer(), 1024): pairs.append((opcode, stack_pointer + 1)) return pairs diff --git 
a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index dd89ef1ab..ea5ccf8cf 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -6,7 +6,7 @@ class StepState: Step state EVM circuit tracks step by step and used to ensure the execution trace is verified continuously and chronologically. It includes fields that are used from beginning to end like is_root, - is_create and opcode_source. + is_create and code_source. It also includes call's mutable states which change almost every step like program_counter and stack_pointer. """ @@ -18,15 +18,15 @@ class StepState: # The following 3 fields decide the opcode source. There are 2 possible # cases: # 1. Root creation call (is_root and is_create) - # It was planned to set the opcode_source to tx_id, then lookup tx_table's + # It was planned to set the code_source to tx_id, then lookup tx_table's # CallData field directly, but is still yet to be determined. # See the issue https://github.com/appliedzkp/zkevm-specs/issues/73 for # further discussion. # 2. Deployed contract interaction or internal creation call - # We set opcode_source to bytecode_hash and lookup bytecode_table. + # We set code_source to bytecode_hash and lookup bytecode_table. is_root: bool is_create: bool - opcode_source: int + code_source: int # The following fields change almost every step. 
program_counter: int @@ -45,10 +45,10 @@ def __init__( self, execution_state: ExecutionState, rw_counter: int, - call_id: int, + call_id: int = 0, is_root: bool = False, is_create: bool = False, - opcode_source: int = 0, + code_source: int = 0, program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, @@ -63,7 +63,7 @@ def __init__( self.call_id = call_id self.is_root = is_root self.is_create = is_create - self.opcode_source = opcode_source + self.code_source = code_source self.program_counter = program_counter self.stack_pointer = stack_pointer self.gas_left = gas_left diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index a970b09d2..4c151c682 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,5 +1,7 @@ +from __future__ import annotations from typing import Sequence, Set, Tuple from enum import IntEnum, auto +from itertools import chain, product from ..util import Array3, Array4, Array8 from .execution_state import ExecutionState @@ -38,6 +40,62 @@ class FixedTableTag(IntEnum): StackOverflow = auto() # opcode, stack_pointer, 0 StackUnderflow = auto() # opcode, stack_pointer, 0 + def table_assignments(self) -> Sequence[Array4]: + if self == FixedTableTag.Range16: + return [(self, i, 0, 0) for i in range(16)] + elif self == FixedTableTag.Range32: + return [(self, i, 0, 0) for i in range(32)] + elif self == FixedTableTag.Range64: + return [(self, i, 0, 0) for i in range(64)] + elif self == FixedTableTag.Range256: + return [(self, i, 0, 0) for i in range(256)] + elif self == FixedTableTag.Range512: + return [(self, i, 0, 0) for i in range(512)] + elif self == FixedTableTag.Range1024: + return [(self, i, 0, 0) for i in range(1024)] + elif self == FixedTableTag.SignByte: + return [(self, i, (i & 1) * 0xFF, 0) for i in range(256)] + elif self == FixedTableTag.BitwiseAnd: + return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.BitwiseOr: + 
return [(self, lhs, rhs, lhs | rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.BitwiseXor: + return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.ResponsibleOpcode: + return [ + (self, execution_state, opcode, 0) + for execution_state in list(ExecutionState) + for opcode in execution_state.responsible_opcode() + ] + elif self == FixedTableTag.InvalidOpcode: + return [(self, opcode, 0, 0) for opcode in invalid_opcodes()] + elif self == FixedTableTag.StateWriteOpcode: + return [(self, opcode, 0, 0) for opcode in state_write_opcodes()] + elif self == FixedTableTag.StackOverflow: + return [(self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_underflow_pairs()] + elif self == FixedTableTag.StackUnderflow: + return [(self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs()] + else: + ValueError("Unreacheable") + + def range_table_tag(range: int) -> FixedTableTag: + if range == 16: + return FixedTableTag.Range16 + elif range == 32: + return FixedTableTag.Range32 + elif range == 64: + return FixedTableTag.Range64 + elif range == 256: + return FixedTableTag.Range256 + elif range == 512: + return FixedTableTag.Range512 + elif range == 1024: + return FixedTableTag.Range1024 + else: + raise ValueError( + f"Range {range} lookup is not supported yet, please add a new variant Range{range} in FixedTableTag with proper table assignments" + ) + class BlockContextFieldTag(IntEnum): """ @@ -89,7 +147,7 @@ class RWTableTag(IntEnum): """ TxAccessListAccount = auto() - TxAccessListStorageSlot = auto() + TxAccessListAccountStorage = auto() TxRefund = auto() Account = auto() @@ -105,7 +163,7 @@ class RWTableTag(IntEnum): def write_with_reversion(self) -> bool: return self in [ RWTableTag.TxAccessListAccount, - RWTableTag.TxAccessListStorageSlot, + RWTableTag.TxAccessListAccountStorage, RWTableTag.Account, RWTableTag.AccountStorage, ] @@ -138,8 +196,8 @@ 
class CallContextFieldTag(IntEnum): # It's not like transaction or bytecode that require specifically friendly # layout for verification, so maintaining the consistency directly in # RWTable seems more intuitive than creating another table for it. - RWCounterEndOfReversion = auto() # to know at which point in the future we should revert - CallerCallId = auto() # to know caller's id + RwCounterEndOfReversion = auto() # to know at which point in the future we should revert + CallerId = auto() # to know caller's id TxId = auto() # to know tx's id Depth = auto() # to know if call too deep CallerAddress = auto() @@ -149,10 +207,17 @@ class CallContextFieldTag(IntEnum): ReturnDataOffset = auto() # for callee to set return_data to caller's memeory ReturnDataLength = auto() Value = auto() - Result = auto() # to peek result in the future + IsSuccess = auto() # to peek result in the future IsPersistent = auto() # to know if current call is within reverted call or not IsStatic = auto() # to know if state modification is within static call or not + # The following are read-only data inside a call like previous section for + # opcode RETURNDATASIZE and RETURNDATACOPY, except they will be updated when + # end of callee execution. + LastCalleeId = auto() + LastCalleeReturnDataOffset = auto() + LastCalleeReturnDataLength = auto() + # The following are used by caller to save its own CallState when it's # going to dive into another call, and will be read out to restore caller's # CallState in the end by callee. @@ -161,7 +226,7 @@ class CallContextFieldTag(IntEnum): # different kinds of RWTableTag. 
IsRoot = auto() IsCreate = auto() - OpcodeSource = auto() + CodeSource = auto() ProgramCounter = auto() StackPointer = auto() GasLeft = auto() @@ -193,33 +258,7 @@ class Tables: # - value1 # - value2 # - value3 - fixed_table: Set[Array4] = set( - [(FixedTableTag.Range16, i, 0, 0) for i in range(16)] - + [(FixedTableTag.Range32, i, 0, 0) for i in range(32)] - + [(FixedTableTag.Range64, i, 0, 0) for i in range(64)] - + [(FixedTableTag.Range256, i, 0, 0) for i in range(256)] - + [(FixedTableTag.Range512, i, 0, 0) for i in range(512)] - + [(FixedTableTag.Range1024, i, 0, 0) for i in range(1024)] - + [(FixedTableTag.SignByte, i, (i & 1) * 0xFF, 0) for i in range(256)] - + [(FixedTableTag.BitwiseAnd, lhs, rhs, lhs & rhs) for lhs in range(256) for rhs in range(256)] - + [(FixedTableTag.BitwiseOr, lhs, rhs, lhs | rhs) for lhs in range(256) for rhs in range(256)] - + [(FixedTableTag.BitwiseXor, lhs, rhs, lhs ^ rhs) for lhs in range(256) for rhs in range(256)] - + [ - (FixedTableTag.ResponsibleOpcode, execution_state, opcode, 0) - for execution_state in list(ExecutionState) - for opcode in execution_state.responsible_opcode() - ] - + [(FixedTableTag.InvalidOpcode, opcode, 0, 0) for opcode in invalid_opcodes()] - + [(FixedTableTag.StateWriteOpcode, opcode, 0, 0) for opcode in state_write_opcodes()] - + [ - (FixedTableTag.StackUnderflow, opcode, stack_pointer, 0) - for (opcode, stack_pointer) in stack_underflow_pairs() - ] - + [ - (FixedTableTag.StackOverflow, opcode, stack_pointer, 0) - for (opcode, stack_pointer) in stack_overflow_pairs() - ] - ) + fixed_table: Set[Array4] = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) # Each row in BlockTable contains: # - tag diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index f3bce88fa..f78a592d8 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,4 +1,5 @@ -from typing import Iterator, Optional, Sequence, Union +from __future__ import annotations +from 
typing import Any, Iterator, Optional, Sequence from functools import reduce from itertools import chain @@ -8,13 +9,15 @@ U256, Array3, Array4, - RLCStore, + RLC, keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, + EMPTY_HASH, + EMPTY_TRIE_HASH, ) from .table import BlockContextFieldTag, TxContextFieldTag -from .opcode import get_push_size +from .opcode import get_push_size, Opcode class Block: @@ -49,16 +52,16 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, rlc_store: RLCStore) -> Sequence[Array3]: + def table_assignments(self, randomness: int) -> Sequence[Array3]: return [ (BlockContextFieldTag.Coinbase, 0, self.coinbase), (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.BlockNumber, 0, rlc_store.to_rlc(self.block_number, 32)), - (BlockContextFieldTag.Time, 0, rlc_store.to_rlc(self.time, 32)), - (BlockContextFieldTag.Difficulty, 0, rlc_store.to_rlc(self.difficulty, 32)), - (BlockContextFieldTag.BaseFee, 0, rlc_store.to_rlc(self.base_fee, 32)), + (BlockContextFieldTag.BlockNumber, 0, RLC(self.block_number, randomness)), + (BlockContextFieldTag.Time, 0, RLC(self.time, randomness)), + (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), + (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), ] + [ - (BlockContextFieldTag.BlockHash, self.block_number - idx - 1, rlc_store.to_rlc(block_hash, 32)) + (BlockContextFieldTag.BlockHash, self.block_number - idx - 1, RLC(block_hash, randomness)) for idx, block_hash in enumerate(reversed(self.history_hashes)) ] @@ -93,63 +96,106 @@ def __init__( self.value = value self.call_data = call_data - def table_assignments(self, rlc_store: RLCStore) -> Iterator[Array4]: - def call_data_gas_cost_per_byte(byte: int): - return GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE if byte is 0 else GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE + def call_data_gas_cost(self) -> int: + return reduce( + lambda 
acc, byte: ( + acc + (GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE if byte is 0 else GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE) + ), + self.call_data, + 0, + ) - call_data_gas_cost = reduce(lambda acc, byte: acc + call_data_gas_cost_per_byte(byte), self.call_data, 0) + def table_assignments(self, randomness: int) -> Iterator[Array4]: return chain( [ (self.id, TxContextFieldTag.Nonce, 0, self.nonce), (self.id, TxContextFieldTag.Gas, 0, self.gas), - (self.id, TxContextFieldTag.GasPrice, 0, rlc_store.to_rlc(self.gas_price, 32)), + (self.id, TxContextFieldTag.GasPrice, 0, RLC(self.gas_price, randomness)), (self.id, TxContextFieldTag.CallerAddress, 0, self.caller_address), (self.id, TxContextFieldTag.CalleeAddress, 0, self.callee_address), (self.id, TxContextFieldTag.IsCreate, 0, self.callee_address is None), - (self.id, TxContextFieldTag.Value, 0, rlc_store.to_rlc(self.value, 32)), + (self.id, TxContextFieldTag.Value, 0, RLC(self.value, randomness)), (self.id, TxContextFieldTag.CallDataLength, 0, len(self.call_data)), - (self.id, TxContextFieldTag.CallDataGasCost, 0, call_data_gas_cost), + (self.id, TxContextFieldTag.CallDataGasCost, 0, self.call_data_gas_cost()), ], map(lambda item: (self.id, TxContextFieldTag.CallData, item[0], item[1]), enumerate(self.call_data)), ) class Bytecode: - hash: U256 - bytes: bytes - - def __init__( - self, - str_or_bytes: Union[str, bytes], - ): - if type(str_or_bytes) is str: - str_or_bytes = bytes.fromhex(str_or_bytes) - - self.hash = int.from_bytes(keccak256(str_or_bytes), "little") - self.bytes = str_or_bytes - - def table_assignments(self, rlc_store: RLCStore) -> Iterator[Array4]: + code: bytearray + + def __init__(self, code: Optional[bytearray] = None) -> None: + self.code = bytearray() if code is None else code + + def __getattr__(self, name: str): + def method(*args) -> Bytecode: + try: + opcode = Opcode[name.removesuffix("_").upper()] + except KeyError: + raise ValueError(f"Invalid opcode {name}") + + if opcode.is_push(): + assert 
len(args) == 1 + self.push(args[0], opcode - Opcode.PUSH1 + 1) + elif opcode.is_dup() or opcode.is_swap(): + assert len(args) == 0 + self.code.append(opcode) + else: + assert len(args) <= 1024 - opcode.max_stack_pointer() + for arg in reversed(args): + self.push(arg, 32) + self.code.append(opcode) + + return self + + return method + + def push(self, value: Any, n_bytes: int = 32) -> Bytecode: + if isinstance(value, int): + value = value.to_bytes(n_bytes, "big") + elif isinstance(value, str): + value = bytes.fromhex(value.lower().removeprefix("0x")) + elif isinstance(value, RLC): + value = value.be_bytes() + elif isinstance(value, bytes) or isinstance(value, bytearray): + ... + else: + raise NotImplementedError(f"Value of type {type(value)} is not yet supported") + + assert 0 < len(value) <= n_bytes, ValueError("Too many bytes as data portion of PUSH*") + + opcode = Opcode.PUSH1 + n_bytes - 1 + self.code.append(opcode) + self.code.extend(value.rjust(n_bytes, bytes(1))) + + return self + + def hash(self) -> int: + return int.from_bytes(keccak256(self.code), "big") + + def table_assignments(self, randomness: int) -> Iterator[Array4]: class BytecodeIterator: idx: int push_data_left: int - hash: int - bytes: bytes + hash: RLC + code: bytes - def __init__(self, hash: int, bytes: bytes): + def __init__(self, hash: RLC, code: bytes): self.idx = 0 self.push_data_left = 0 self.hash = hash - self.bytes = bytes + self.code = code def __iter__(self): return self def __next__(self): - if self.idx == len(self.bytes): + if self.idx == len(self.code): raise StopIteration idx = self.idx - byte = self.bytes[idx] + byte = self.code[idx] is_code = self.push_data_left == 0 self.push_data_left = get_push_size(byte) if is_code else self.push_data_left - 1 @@ -158,4 +204,26 @@ def __next__(self): return (self.hash, idx, byte, is_code) - return BytecodeIterator(rlc_store.to_rlc(self.hash, 32), self.bytes) + return BytecodeIterator(RLC(self.hash(), randomness), self.code) + + +class Account: 
+ address: U160 + nonce: U256 + balance: U256 + code_hash: U256 + storage_trie_hash: U256 + + def __init__( + self, + address: U160 = 0, + nonce: U256 = 0, + balance: U256 = 0, + code_hash: U256 = EMPTY_HASH, + storage_trie_hash: U256 = EMPTY_TRIE_HASH, + ) -> None: + self.address = address + self.nonce = nonce + self.balance = balance + self.code_hash = code_hash + self.storage_trie_hash = storage_trie_hash diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 2bfab45de..3c8552c4b 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -8,14 +8,14 @@ from .typing import * -def hex_to_word(hex: str) -> bytes: - return bytes.fromhex(hex.removeprefix("0x").zfill(64)) - - def rand_range(stop: Union[int, float] = 2 ** 256) -> int: return randrange(0, int(stop)) +def rand_fp() -> int: + return rand_range(FP_MODULUS) + + def rand_address() -> U160: return rand_range(2 ** 160) diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index 8f70b6265..068ef6d32 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -1,9 +1,5 @@ -from typing import Dict, Sequence, Tuple, Union -from Crypto.Random import get_random_bytes -from Crypto.Random.random import randrange - -from .param import MAX_N_BYTES - +from __future__ import annotations +from typing import Sequence, Union # BN254 scalar field size FP_MODULUS = 21888242871839275222246405745257275088548364400416034343698204186575808495617 @@ -21,70 +17,43 @@ def fp_inv(value: int) -> int: return pow(value, -1, FP_MODULUS) -def le_to_int(bytes: Sequence[int]) -> int: - assert len(bytes) <= MAX_N_BYTES, "too many bytes to composite an integer in field" - return linear_combine(bytes, 256) +def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: int) -> int: + com = 0 + for byte in reversed(le_bytes): + assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" 
+ com = fp_add(fp_mul(com, factor), byte) + return com + +class RLC: + le_bytes: bytes + value: int -def linear_combine(bytes: Sequence[int], r: int) -> int: - ret = 0 - for byte in reversed(bytes): - assert 0 <= byte < 256, "bytes for linear combination should be already checked in range" - ret = fp_add(fp_mul(ret, r), byte) - return ret + def __init__(self, int_or_bytes: Union[int, bytes], randomness: int, n_bytes: int = 32) -> None: + if isinstance(int_or_bytes, int): + assert 0 <= int_or_bytes < 256 ** n_bytes, f"Value {int_or_bytes} too large to fit {n_bytes} bytes" + self.le_bytes = int_or_bytes.to_bytes(n_bytes, "little") + elif isinstance(int_or_bytes, bytes): + assert len(int_or_bytes) <= n_bytes, f"Expected bytes with length less or equal than {n_bytes}" + self.le_bytes = int_or_bytes.ljust(n_bytes, b"\x00") + else: + raise TypeError(f"Expected an int or bytes, but got object of type {type(int_or_bytes)}") + self.value = fp_linear_combine(self.le_bytes, randomness) -class RLCStore: - randomness: int - rlc_to_bytes: Dict[int, bytes] = dict() + def __eq__(self, rhs: Union[int, RLC]): + if isinstance(rhs, int): + return self.value == rhs + if isinstance(rhs, RLC): + return self.value == rhs.value + else: + raise TypeError(f"Expected a RLC, but got object of type {type(rhs)}") - def __init__(self, randomness: int = randrange(0, FP_MODULUS)) -> None: - self.randomness = randomness - for byte in range(256): - self.to_rlc([byte]) + def __hash__(self) -> int: + return self.value - def to_rlc(self, seq_or_int: Union[Sequence[int], int], n_bytes: int = 0) -> int: - seq = seq_or_int - if type(seq_or_int) == int: - seq = seq_or_int.to_bytes(n_bytes, "little") - rlc = linear_combine(seq, self.randomness) + def __repr__(self) -> str: + return int.from_bytes(self.le_bytes, "little").__repr__() - if rlc in self.rlc_to_bytes: - maxlen = max(len(self.rlc_to_bytes[rlc]), len(seq)) - assert self.rlc_to_bytes[rlc].rjust(maxlen, b"\x00") == bytes(seq).rjust( - maxlen, b"\x00" 
- ), f"Random lienar combination collision on {self.rlc_to_bytes[rlc]} and {bytes(seq)} with randomness {self.randomness}" - else: - self.rlc_to_bytes[rlc] = bytes(seq) - - return rlc - - def to_bytes(self, rlc: int) -> bytes: - return self.rlc_to_bytes[rlc] - - def rand(self, n_bytes: int = 32) -> Tuple[int, bytes]: - bytes = get_random_bytes(n_bytes) - return self.to_rlc(bytes), bytes - - def add(self, lhs: int, rhs: int, modulus: int = 2 ** 256) -> Tuple[int, bytes, bool]: - lhs_bytes = self.to_bytes(lhs) - rhs_bytes = self.to_bytes(rhs) - carry, result = divmod( - int.from_bytes(lhs_bytes, "little") + int.from_bytes(rhs_bytes, "little"), - modulus, - ) - result_bytes = result.to_bytes(32, "little") - return self.to_rlc(result_bytes), result_bytes, carry > 0 - - def sub(self, lhs: int, rhs: int, modulus: int = 2 ** 256) -> Tuple[int, bytes, bool]: - lhs_bytes = self.to_bytes(lhs) - rhs_bytes = self.to_bytes(rhs) - borrow, result = divmod( - int.from_bytes(lhs_bytes, "little") - int.from_bytes(rhs_bytes, "little"), - modulus, - ) - assert ( - result + int.from_bytes(rhs_bytes, "little") == int.from_bytes(lhs_bytes, "little") + (borrow < 0) * modulus - ) - result_bytes = result.to_bytes(32, "little") - return self.to_rlc(result_bytes), result_bytes, borrow < 0 + def be_bytes(self) -> bytes: + return bytes(reversed(self.le_bytes)) diff --git a/src/zkevm_specs/util/hash.py b/src/zkevm_specs/util/hash.py index f196e3d6f..99d750cff 100644 --- a/src/zkevm_specs/util/hash.py +++ b/src/zkevm_specs/util/hash.py @@ -1,13 +1,15 @@ from typing import Union from Crypto.Hash import keccak +from .typing import U256 -def keccak256(data: Union[str, bytes]) -> bytes: + +def keccak256(data: Union[str, bytes, bytearray]) -> bytes: if type(data) == str: data = bytes.fromhex(data) return keccak.new(digest_bits=256).update(data).digest() -EMPTY_HASH = keccak256("") -EMPTY_CODE_HASH = EMPTY_HASH -EMPTY_TRIE_HASH = keccak256("80") +EMPTY_HASH: U256 = int.from_bytes(keccak256(""), "big") 
+EMPTY_CODE_HASH: U256 = EMPTY_HASH +EMPTY_TRIE_HASH: U256 = int.from_bytes(keccak256("80"), "big") diff --git a/src/zkevm_specs/util/param.py b/src/zkevm_specs/util/param.py index 71c0192d8..87ff0e194 100644 --- a/src/zkevm_specs/util/param.py +++ b/src/zkevm_specs/util/param.py @@ -1,9 +1,38 @@ # Maximun number of bytes with composition value that doesn't wrap around the field MAX_N_BYTES = 31 +# Number of bytes of account address +N_BYTES_ACCOUNT_ADDRESS = 20 +# Number of bytes of memory address +N_BYTES_MEMORY_ADDRESS = 5 +# Number of bytes of memory size (in word) +N_BYTES_MEMORY_SIZE = 4 # Number of bytes of gas N_BYTES_GAS = 8 +# Number of bytes of program counter +N_BYTES_PROGRAM_COUNTER = 8 +# Gas cost of non-creation transaction +GAS_COST_TX = 21000 +# Gas cost of creation transaction +GAS_COST_CREATION_TX = 53000 # Gas cost of transaction call_data per non-zero byte GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE = 16 # Gas cost of transaction call_data per zero byte GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE = 4 +# Gas cost of accessing account or storage slot +GAS_COST_WARM_ACCESS = 100 +# Extra gas cost of not-yet-accessed account +EXTRA_GAS_COST_ACCOUNT_COLD_ACCESS = 2500 +# Extra gas cost of not-yet-accessed storage slot +EXTRA_GAS_COST_STORAGE_SLOT_COLD_ACCESS = 2000 +# Gas cost of calling with non-zero value +GAS_COST_CALL_WITH_VALUE = 9000 +# Gas cost of calling empty account +GAS_COST_CALL_EMPTY_ACCOUNT = 25000 +# Gas stipend given if call with non-zero value +GAS_STIPEND_CALL_WITH_VALUE = 2300 + +# Denominator of quadratic part of memory expansion gas cost +MEMORY_EXPANSION_QUAD_DENOMINATOR = 512 +# Coefficient of linear part of memory expansion gas cost +MEMORY_EXPANSION_LINEAR_COEFF = 3 diff --git a/tests/evm/test_add.py b/tests/evm/test_add.py index 8ba0371b8..d54a97ddf 100644 --- a/tests/evm/test_add.py +++ b/tests/evm/test_add.py @@ -12,37 +12,32 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util 
import rand_fp, rand_word, RLC TESTING_DATA = ( - (Opcode.ADD, hex_to_word("030201"), hex_to_word("060504"), hex_to_word("090705")), - (Opcode.SUB, hex_to_word("090705"), hex_to_word("060504"), hex_to_word("030201")), - (Opcode.ADD, rand_bytes(), rand_bytes(), None), - (Opcode.SUB, rand_bytes(), rand_bytes(), None), + (Opcode.ADD, 0x030201, 0x060504, 0x090705), + (Opcode.SUB, 0x090705, 0x060504, 0x030201), + (Opcode.ADD, rand_word(), rand_word(), None), + (Opcode.SUB, rand_word(), rand_word(), None), ) -@pytest.mark.parametrize("opcode, a_bytes, b_bytes, c_bytes", TESTING_DATA) -def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[bytes]): - rlc_store = RLCStore() +@pytest.mark.parametrize("opcode, a, b, c", TESTING_DATA) +def test_add(opcode: Opcode, a: int, b: int, c: Optional[int]): + randomness = rand_fp() - a = rlc_store.to_rlc(a_bytes) - b = rlc_store.to_rlc(b_bytes) - c = ( - rlc_store.to_rlc(c_bytes) - if c_bytes is not None - else (rlc_store.add(a, b) if opcode == Opcode.ADD else rlc_store.sub(a, b))[0] - ) + c = RLC(c, randomness) if c is not None else RLC((a + b if opcode == Opcode.ADD else a - b) % 2 ** 256, randomness) + a = RLC(a, randomness) + b = RLC(b, randomness) - block = Block() - bytecode = Bytecode(f"7f{b_bytes.hex()}7f{a_bytes.hex()}{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().add(a, b) if opcode == Opcode.ADD else Bytecode().sub(a, b) + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(Block().table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1022, a, 0, 0), @@ -53,7 +48,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b ) verify_steps( - rlc_store=rlc_store, + 
randomness=randomness, tables=tables, steps=[ StepState( @@ -62,7 +57,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=66, stack_pointer=1022, gas_left=3, @@ -73,7 +68,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=67, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_begin_tx.py b/tests/evm/test_begin_tx.py index 9908a97a6..934871c7e 100644 --- a/tests/evm/test_begin_tx.py +++ b/tests/evm/test_begin_tx.py @@ -13,7 +13,7 @@ Transaction, Bytecode, ) -from zkevm_specs.util import RLCStore, rand_address, rand_range +from zkevm_specs.util import rand_fp, rand_address, rand_range, RLC TESTING_DATA = ( # Transfer 1 ether, successfully @@ -41,31 +41,31 @@ @pytest.mark.parametrize("tx, result", TESTING_DATA) def test_begin_tx(tx: Transaction, result: bool): - rlc_store = RLCStore() + randomness = rand_fp() - block = Block() - caller_balance_prev = rlc_store.to_rlc(int(1e20), 32) - callee_balance_prev = rlc_store.to_rlc(0, 32) - caller_balance = rlc_store.to_rlc(int(1e20) - (tx.value + tx.gas * tx.gas_price), 32) - callee_balance = rlc_store.to_rlc(tx.value, 32) + rw_counter_end_of_reversion = 23 + caller_balance_prev = int(1e20) + callee_balance_prev = 0 + caller_balance = caller_balance_prev - (tx.value + tx.gas * tx.gas_price) + callee_balance = callee_balance_prev + tx.value - bytecode = Bytecode("00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), - tx_table=set(tx.table_assignments(rlc_store)), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + 
block_table=set(Block().table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ - (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, 1, 0, 0), + (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), ( 2, RW.Read, RWTableTag.CallContext, 1, - CallContextFieldTag.RWCounterEndOfReversion, - 0 if result else 20, + CallContextFieldTag.RwCounterEndOfReversion, + 0 if result else rw_counter_end_of_reversion, 0, 0, ), @@ -79,8 +79,8 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, - caller_balance, - caller_balance_prev, + RLC(caller_balance, randomness), + RLC(caller_balance_prev, randomness), 0, ), ( @@ -89,8 +89,8 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.Account, tx.callee_address, AccountFieldTag.Balance, - callee_balance, - callee_balance_prev, + RLC(callee_balance, randomness), + RLC(callee_balance_prev, randomness), 0, ), ( @@ -114,7 +114,7 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.CallContext, 1, CallContextFieldTag.Value, - rlc_store.to_rlc(tx.value, 32), + RLC(tx.value, randomness), 0, 0, ), @@ -125,23 +125,23 @@ def test_begin_tx(tx: Transaction, result: bool): if result else [ ( - 19, + rw_counter_end_of_reversion - 1, RW.Write, RWTableTag.Account, tx.callee_address, AccountFieldTag.Balance, - callee_balance_prev, - callee_balance, + RLC(callee_balance_prev, randomness), + RLC(callee_balance, randomness), 0, ), ( - 20, + rw_counter_end_of_reversion, RW.Write, RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, - caller_balance_prev, - caller_balance, + RLC(caller_balance_prev, randomness), + RLC(caller_balance, randomness), 0, ), ] @@ -150,13 +150,12 @@ def test_begin_tx(tx: Transaction, result: bool): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ 
StepState( execution_state=ExecutionState.BeginTx, rw_counter=1, - call_id=1, ), StepState( execution_state=ExecutionState.STOP if result else ExecutionState.REVERT, @@ -164,7 +163,7 @@ def test_begin_tx(tx: Transaction, result: bool): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=0, diff --git a/tests/evm/test_caller.py b/tests/evm/test_caller.py index 9668f3af0..0e5c31ef9 100644 --- a/tests/evm/test_caller.py +++ b/tests/evm/test_caller.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,33 @@ CallContextFieldTag, Bytecode, ) -from zkevm_specs.util import RLCStore, U160 +from zkevm_specs.util import rand_address, rand_fp, RLC, U160 -TESTING_DATA = ((Opcode.CALLER, 0x030201),) +TESTING_DATA = (0x030201, rand_address()) -@pytest.mark.parametrize("opcode, address", TESTING_DATA) -def test_caller(opcode: Opcode, address: U160): - rlc_store = RLCStore() +@pytest.mark.parametrize("caller", TESTING_DATA) +def test_caller(caller: U160): + randomness = rand_fp() - caller_rlc = rlc_store.to_rlc(address.to_bytes(20, "little")) + bytecode = Bytecode().caller() + bytecode_hash = RLC(bytecode.hash(), randomness) - bytecode = Bytecode(f"{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) tables = Tables( block_table=set(), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, caller_rlc, 0, 0), - (10, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, address, 0, 0), + (9, RW.Write, RWTableTag.Stack, 1, 1023, RLC(caller, randomness, 20), 0, 0), + (10, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, caller, 0, 0), ] ), ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, 
tables=tables, steps=[ StepState( @@ -47,7 +45,7 @@ def test_caller(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=2, @@ -58,7 +56,7 @@ def test_caller(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_coinbase.py b/tests/evm/test_coinbase.py index e82f43995..c63506fd0 100644 --- a/tests/evm/test_coinbase.py +++ b/tests/evm/test_coinbase.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,34 @@ Block, Bytecode, ) -from zkevm_specs.util import RLCStore, U160 +from zkevm_specs.util import rand_address, rand_fp, RLC, U160 -TESTING_DATA = ((Opcode.COINBASE, 0x030201),) +TESTING_DATA = (0x030201, rand_address()) -@pytest.mark.parametrize("opcode, address", TESTING_DATA) -def test_coinbase(opcode: Opcode, address: U160): - rlc_store = RLCStore() +@pytest.mark.parametrize("coinbase", TESTING_DATA) +def test_coinbase(coinbase: U160): + randomness = rand_fp() - coinbase_rlc = rlc_store.to_rlc(address.to_bytes(20, "little")) + block = Block(coinbase=coinbase) + + bytecode = Bytecode().coinbase() + bytecode_hash = RLC(bytecode.hash(), randomness) - bytecode = Bytecode(f"{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) - block = Block(address) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, coinbase_rlc, 0, 0), + (9, RW.Write, RWTableTag.Stack, 1, 1023, RLC(coinbase, randomness, 20), 0, 0), ] ), ) 
verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -47,7 +46,7 @@ def test_coinbase(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=2, @@ -58,7 +57,7 @@ def test_coinbase(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_jump.py b/tests/evm/test_jump.py index 8e6cb1769..06ea99e65 100644 --- a/tests/evm/test_jump.py +++ b/tests/evm/test_jump.py @@ -1,6 +1,5 @@ import pytest -from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -12,7 +11,7 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util import rand_fp, RLC TESTING_DATA = ((Opcode.JUMP, bytes([7])),) @@ -20,19 +19,19 @@ @pytest.mark.parametrize("opcode, dest_bytes", TESTING_DATA) def test_jump(opcode: Opcode, dest_bytes: bytes): - rlc_store = RLCStore() - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 # PUSH1 80 PUSH1 40 PUSH1 07 JUMP JUMPDEST STOP - bytecode = Bytecode(f"6080604060{dest_bytes.hex()}565b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(0x40).push1(dest_bytes).jump().jumpdest().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -41,7 +40,7 @@ def test_jump(opcode: 
Opcode, dest_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -50,7 +49,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=8, @@ -61,7 +60,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=int.from_bytes(dest_bytes, "little"), stack_pointer=1022, gas_left=0, diff --git a/tests/evm/test_jumpi.py b/tests/evm/test_jumpi.py index 221796fee..6c7105c7b 100644 --- a/tests/evm/test_jumpi.py +++ b/tests/evm/test_jumpi.py @@ -1,6 +1,5 @@ import pytest -from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -12,7 +11,7 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util import rand_fp, RLC TESTING_DATA = ((Opcode.JUMPI, bytes([40]), bytes([7])),) @@ -20,20 +19,20 @@ @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA) def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - rlc_store = RLCStore() - cond = rlc_store.to_rlc(bytes(reversed(cond_bytes))) - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + cond = RLC(bytes(reversed(cond_bytes)), randomness) + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 because the condition (40) is nonzero. 
# PUSH1 80 PUSH1 40 PUSH1 07 JUMPI JUMPDEST STOP - bytecode = Bytecode(f"6080604060{dest_bytes.hex()}575b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(0x40).push1(dest_bytes).jumpi().jumpdest().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -43,7 +42,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -52,7 +51,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=10, @@ -63,7 +62,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=int.from_bytes(dest_bytes, "little"), stack_pointer=1023, gas_left=0, @@ -77,20 +76,20 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA_ZERO_COND) def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - rlc_store = RLCStore() - cond = rlc_store.to_rlc(bytes(reversed(cond_bytes))) - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + cond = RLC(bytes(reversed(cond_bytes)), randomness) + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 because the condition (0) is zero. 
# PUSH1 80 PUSH1 0 PUSH1 08 JUMPI STOP - bytecode = Bytecode(f"6080600060{dest_bytes.hex()}575b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(cond_bytes).push1(dest_bytes).jumpi().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -100,7 +99,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -109,7 +108,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=10, @@ -120,7 +119,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=7, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_push.py b/tests/evm/test_push.py index d663a2aaf..252b11a17 100644 --- a/tests/evm/test_push.py +++ b/tests/evm/test_push.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,33 @@ Block, Bytecode, ) -from zkevm_specs.util import rand_bytes, RLCStore +from zkevm_specs.util import rand_bytes, rand_fp, RLC TESTING_DATA = tuple( [ - (Opcode.PUSH1, bytes([1])), - (Opcode.PUSH2, bytes([2, 1])), - (Opcode.PUSH31, bytes([i for i in range(31, 0, -1)])), - (Opcode.PUSH32, bytes([i for i in range(32, 0, -1)])), + (bytes([1])), + (bytes([2, 1])), + (bytes([i for 
i in range(31, 0, -1)])), + (bytes([i for i in range(32, 0, -1)])), ] - + [(Opcode(Opcode.PUSH1 + i), rand_bytes(i + 1)) for i in range(32)] + + [(rand_bytes(i + 1)) for i in range(32)] ) -@pytest.mark.parametrize("opcode, value_be_bytes", TESTING_DATA) -def test_push(opcode: Opcode, value_be_bytes: bytes): - rlc_store = RLCStore() +@pytest.mark.parametrize("value_be_bytes", TESTING_DATA) +def test_push(value_be_bytes: bytes): + randomness = rand_fp() - value = rlc_store.to_rlc(bytes(reversed(value_be_bytes))) + value = RLC(bytes(reversed(value_be_bytes)), randomness) - block = Block() - bytecode = Bytecode(f"{opcode.hex()}{value_be_bytes.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push(value_be_bytes, n_bytes=len(value_be_bytes)) + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(Block().table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (8, RW.Write, RWTableTag.Stack, 1, 1023, value, 0, 0), @@ -47,7 +45,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -56,7 +54,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=3, @@ -67,7 +65,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1 + len(value_be_bytes), stack_pointer=1023, gas_left=0, diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index cc7308e7f..99c0a0739 100644 --- a/tests/test_bytecode_circuit.py +++ 
b/tests/test_bytecode_circuit.py @@ -3,22 +3,22 @@ import traceback from copy import deepcopy from zkevm_specs.evm import Bytecode -from zkevm_specs.util import RLCStore +from zkevm_specs.util import RLC, rand_fp # Unroll the bytecode -def unroll(bytecode, rlc_store): - return UnrolledBytecode(bytecode, list(Bytecode(bytecode).table_assignments(rlc_store))) +def unroll(bytecode, randomness): + return UnrolledBytecode(bytecode, list(Bytecode(bytecode).table_assignments(randomness))) # Verify the bytecode circuit with the given data -def verify(k, bytecodes, rlc_store, success): +def verify(k, bytecodes, randomness, success): push_table = assign_push_table() - keccak_table = assign_keccak_table(map(lambda v: v.bytes, bytecodes), rlc_store) - rows = assign_bytecode_circuit(k, bytecodes, rlc_store) + keccak_table = assign_keccak_table(map(lambda v: v.bytes, bytecodes), randomness) + rows = assign_bytecode_circuit(k, bytecodes, randomness) try: for (idx, row) in enumerate(rows): prev_row = rows[(idx - 1) % len(rows)] - check_bytecode_row(row, prev_row, push_table, keccak_table, rlc_store.randomness) + check_bytecode_row(row, prev_row, push_table, keccak_table, randomness) ok = True except AssertionError as e: if success: @@ -28,7 +28,7 @@ def verify(k, bytecodes, rlc_store, success): k = 10 -rlc_store = RLCStore() +randomness = rand_fp() def test_bytecode_unrolling(): @@ -48,110 +48,104 @@ def test_bytecode_unrolling(): for _ in range(n): rows.append((0, len(rows), data_byte, False)) # Set the hash of the complete bytecode in the rows - hash = rlc_store.to_rlc(keccak256(bytes(bytecode)), 32) + hash = RLC(bytes(reversed(keccak256(bytes(bytecode)))), randomness) for i in range(len(rows)): rows[i] = (hash, rows[i][1], rows[i][2], rows[i][3]) # Unroll the bytecode - unrolled = unroll(bytes(bytecode), rlc_store) + unrolled = unroll(bytes(bytecode), randomness) # Check if the bytecode was unrolled correctly assert UnrolledBytecode(bytes(bytecode), rows) == unrolled # Verify 
the unrolling in the circuit - verify(k, [unrolled], rlc_store, True) + verify(k, [unrolled], randomness, True) def test_bytecode_empty(): - bytecodes = [ - unroll(bytes([]), rlc_store), - ] - verify(k, bytecodes, rlc_store, True) + bytecodes = [unroll(bytes([]), randomness)] + verify(k, bytecodes, randomness, True) def test_bytecode_full(): - bytecodes = [ - unroll(bytes([7] * 2 ** k), rlc_store), - ] - verify(k, bytecodes, rlc_store, True) + bytecodes = [unroll(bytes([7] * 2 ** k), randomness)] + verify(k, bytecodes, randomness, True) def test_bytecode_incomplete(): - bytecodes = [ - unroll(bytes([7] * (2 ** k + 1)), rlc_store), - ] - verify(k, bytecodes, rlc_store, False) + bytecodes = [unroll(bytes([7] * (2 ** k + 1)), randomness)] + verify(k, bytecodes, randomness, False) def test_bytecode_multiple(): bytecodes = [ - unroll(bytes([]), rlc_store), - unroll(bytes([Opcode.PUSH32]), rlc_store), - unroll(bytes([Opcode.PUSH32, Opcode.ADD]), rlc_store), - unroll(bytes([Opcode.ADD, Opcode.PUSH32]), rlc_store), - unroll(bytes([Opcode.ADD, Opcode.PUSH32, Opcode.ADD]), rlc_store), + unroll(bytes([]), randomness), + unroll(bytes([Opcode.PUSH32]), randomness), + unroll(bytes([Opcode.PUSH32, Opcode.ADD]), randomness), + unroll(bytes([Opcode.ADD, Opcode.PUSH32]), randomness), + unroll(bytes([Opcode.ADD, Opcode.PUSH32, Opcode.ADD]), randomness), ] - verify(k, bytecodes, rlc_store, True) + verify(k, bytecodes, randomness, True) def test_bytecode_invalid_hash_data(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Change the hash on the first position invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Change the 
hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): invalid.rows[idx] = (1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_index(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Start the index at 1 invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[idx] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) row = unrolled.rows[-1] - invalid.rows[-1] = (row[0] - 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[-1] = (row[0].value - 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_byte_data(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Change the first byte invalid = deepcopy(unrolled) row = unrolled.rows[0] invalid.rows[0] = (row[0], row[1], row[2], 9) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Change a byte on another position invalid = deepcopy(unrolled) row = unrolled.rows[5] invalid.rows[5] = (row[0], row[1], row[2], 6) - verify(k, [invalid], 
rlc_store, False) + verify(k, [invalid], randomness, False) # Set a byte value out of range invalid = deepcopy(unrolled) row = unrolled.rows[3] invalid.rows[3] = (row[0], row[1], row[2], 256) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_is_code(): @@ -167,24 +161,24 @@ def test_bytecode_invalid_is_code(): Opcode.PUSH6, ] ), - rlc_store, + randomness, ) - verify(k, [unrolled], rlc_store, True) + verify(k, [unrolled], randomness, True) # Mark the 3rd byte as code (is push data from the first PUSH1) invalid = deepcopy(unrolled) row = unrolled.rows[2] invalid.rows[2] = (row[0], row[1], 1, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Mark the 4rd byte as data (is code) invalid = deepcopy(unrolled) row = unrolled.rows[3] invalid.rows[3] = (row[0], row[1], 0, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Mark the 7th byte as code (is data for the PUSH7) invalid = deepcopy(unrolled) row = unrolled.rows[6] invalid.rows[6] = (row[0], row[1], 1, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) From 94abf3423e2d5a0a3c06066cd49a9a3141ed052e Mon Sep 17 00:00:00 2001 From: han0110 Date: Thu, 13 Jan 2022 02:51:18 +0800 Subject: [PATCH 2/8] feat: implement ExecutionState EndTx and EndBlock --- src/zkevm_specs/evm/execution/__init__.py | 4 ++ src/zkevm_specs/evm/execution/end_block.py | 30 +++++++++++++++ src/zkevm_specs/evm/execution/end_tx.py | 43 ++++++++++++++++++++++ src/zkevm_specs/evm/execution_state.py | 2 + src/zkevm_specs/evm/main.py | 5 +-- 5 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 src/zkevm_specs/evm/execution/end_block.py create mode 100644 src/zkevm_specs/evm/execution/end_tx.py diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 26921e554..429364adc 100644 --- 
a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -3,6 +3,8 @@ from ..execution_state import ExecutionState from .begin_tx import * +from .end_tx import * +from .end_block import * # Opcode's successful cases from .add import * @@ -15,6 +17,8 @@ EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { ExecutionState.BeginTx: begin_tx, + ExecutionState.EndTx: end_tx, + ExecutionState.EndBlock: end_block, ExecutionState.ADD: add, ExecutionState.CALLER: caller, ExecutionState.COINBASE: coinbase, diff --git a/src/zkevm_specs/evm/execution/end_block.py b/src/zkevm_specs/evm/execution/end_block.py new file mode 100644 index 000000000..4ba98a1e9 --- /dev/null +++ b/src/zkevm_specs/evm/execution/end_block.py @@ -0,0 +1,30 @@ +from ..instruction import Instruction, Transition +from ..table import CallContextFieldTag + + +# TODO: Introduce constrain_instance to constrain the equality between witness +# and public input, for total_tx and total_rw + + +def end_block(instruction: Instruction): + if instruction.is_last_step: + # Verify final step has tx_id identical to the tx amount in tx_table. + total_tx = instruction.call_context_lookup(CallContextFieldTag.TxId) + instruction.constrain_equal( + total_tx, + max([tx_id for tx_id, *_ in instruction.tables.tx_table]), + ) + + # Verify rw_counter counts to identical rw amount in rw_table to ensure + # there is no malicious insertion. 
+ total_rw = instruction.curr.rw_counter + 1 # extra 1 from the tx_id lookup + instruction.constrain_equal( + total_rw, + len(instruction.tables.rw_table), + ) + else: + # Propagate rw_counter and call_id all the way down + instruction.constrain_step_state_transition( + rw_counter=Transition.same(), + call_id=Transition.same(), + ) diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py new file mode 100644 index 000000000..af28a63c6 --- /dev/null +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -0,0 +1,43 @@ +from ...util import N_BYTES_GAS +from ..execution_state import ExecutionState +from ..instruction import Instruction +from ..table import BlockContextFieldTag, CallContextFieldTag, TxContextFieldTag + + +def end_tx(instruction: Instruction): + tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) + + # Handle gas refund (refund is capped to gas_used // 2 in EIP 3529) + tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas) + gas_used = tx_gas - instruction.curr.gas_left + capped_refund, _ = instruction.constant_divmod(gas_used, 2, N_BYTES_GAS) + accumulated_refund = instruction.tx_refund_read(tx_id) + refund = instruction.min(capped_refund, accumulated_refund, 8) + + # Add refund * gas_price back to caller's balance + tx_gas_price = instruction.tx_gas_price(tx_id) + value, carry = instruction.mul_word_by_u64(tx_gas_price, refund) + instruction.constrain_zero(carry) + tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) + instruction.add_balance(tx_caller_address, [value]) + + # Add gas_used * effective_tip to coinbase's balance + base_fee = instruction.block_context_lookup(BlockContextFieldTag.BaseFee) + effective_tip, _ = instruction.sub_word(tx_gas_price, base_fee) + reward, carry = instruction.mul_word_by_u64(effective_tip, gas_used) + instruction.constrain_zero(carry) + coinbase = instruction.block_context_lookup(BlockContextFieldTag.Coinbase) + 
instruction.add_balance(coinbase, [reward]) + + # Do step state transition for rw_counter + instruction.constrain_equal(instruction.next.rw_counter, instruction.curr.rw_counter + 7) + + # Go to next transaction + if instruction.next.execution_state == ExecutionState.BeginTx: + # Check next tx_id is increased by 1 + instruction.constrain_equal( + instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=instruction.next.rw_counter), + tx_id + 1, + ) + + # Or to ExecutionState.EndBlock diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index e51506235..2069a3398 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -10,6 +10,8 @@ class ExecutionState(IntEnum): """ BeginTx = auto() + EndTx = auto() + EndBlock = auto() # Opcode's successful cases STOP = auto() diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index 50dcff42e..d224afc31 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -39,7 +39,4 @@ def verify_step( raise NotImplementedError if instruction.is_last_step: - # Verify no malicious insertion - assert instruction.curr.rw_counter == len(instruction.tables.rw_table) - - # TODO: Verify final step has the tx_id identical to the amount in tx_table + instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.EndBlock) From d5a45cd8925b024e676577cef496c9da10def77f Mon Sep 17 00:00:00 2001 From: han0110 Date: Thu, 13 Jan 2022 18:07:38 +0800 Subject: [PATCH 3/8] feat: enforce ExecutionState transition constraint for BeginTx --- src/zkevm_specs/evm/execution/begin_tx.py | 14 +++- .../evm/execution/block_coinbase.py | 8 +- src/zkevm_specs/evm/execution/caller.py | 8 +- src/zkevm_specs/evm/instruction.py | 28 +++---- src/zkevm_specs/evm/main.py | 2 + tests/evm/test_begin_tx.py | 77 +++++++++++++++---- 6 files changed, 94 insertions(+), 43 deletions(-) diff --git a/src/zkevm_specs/evm/execution/begin_tx.py 
b/src/zkevm_specs/evm/execution/begin_tx.py index beeaa4ae3..907b4e9e6 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,7 +1,8 @@ -from ...util import GAS_COST_TX, GAS_COST_CREATION_TX +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH +from ..execution_state import ExecutionState from ..instruction import Instruction, Transition -from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag from ..precompiled import PrecompiledAddress +from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag def begin_tx(instruction: Instruction): @@ -56,7 +57,7 @@ def begin_tx(instruction: Instruction): ) if tx_is_create: - # TODO: Verify receiver address + # TODO: Verify created address # TODO: Decide what code_source should be (tx_id or hash of creation code) raise NotImplementedError elif tx_callee_address in list(PrecompiledAddress): @@ -91,3 +92,10 @@ def begin_tx(instruction: Instruction): gas_left=Transition.to(gas_left), state_write_counter=Transition.to(2), ) + + # Constrain either: + # - is_empty_code and is_to_end_tx + # - (not is_empty_code) and (not is_to_end_tx) + is_empty_code = instruction.is_equal(code_hash, instruction.int_to_rlc(EMPTY_CODE_HASH, 32)) + is_to_end_tx = instruction.is_equal(instruction.next.execution_state, ExecutionState.EndTx) + instruction.constrain_equal(is_empty_code + is_to_end_tx, 2 * is_empty_code * is_to_end_tx) diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 18c537f84..1cffca444 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -11,11 +11,9 @@ def coinbase(instruction: Instruction): # check block table for coinbase address instruction.constrain_equal( address, - instruction.bytes_to_rlc( - instruction.int_to_bytes( - instruction.block_context_lookup(BlockContextFieldTag.Coinbase), - 20, - ) 
+ instruction.int_to_rlc( + instruction.block_context_lookup(BlockContextFieldTag.Coinbase), + 20, ), ) diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 1d0011394..d7cf9f209 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -11,11 +11,9 @@ def caller(instruction: Instruction): # check [rw_table, call_context] table for caller address instruction.constrain_equal( address, - instruction.bytes_to_rlc( - instruction.int_to_bytes( - instruction.call_context_lookup(CallContextFieldTag.CallerAddress), - 20, - ) + instruction.int_to_rlc( + instruction.call_context_lookup(CallContextFieldTag.CallerAddress), + 20, ), ) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 13cb8aa2b..1857b18ca 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -98,7 +98,7 @@ def constrain_bool(self, value: int): assert value in [0, 1], ConstraintUnsatFailure(f"Expected value to be a bool, but got {value}") def constrain_gas_left_not_underflow(self, gas_left: int): - self.int_to_bytes(gas_left, N_BYTES_GAS) + self.bytes_range_lookup(gas_left, N_BYTES_GAS) def constrain_step_state_transition(self, **kwargs: Transition): keys = set( @@ -214,7 +214,7 @@ def pair_select(self, value: int, lhs: int, rhs: int) -> Tuple[bool, bool]: def constant_divmod(self, numerator: int, denominator: int, n_bytes: int) -> Tuple[int, int]: quotient, remainder = divmod(numerator, denominator) - self.int_to_bytes(quotient, n_bytes) + self.bytes_range_lookup(quotient, n_bytes) return quotient, remainder def compare(self, lhs: int, rhs: int, n_bytes: int) -> Tuple[bool, bool]: @@ -283,15 +283,21 @@ def word_to_lo_hi(self, word: RLC) -> Tuple[Sequence[int], Sequence[int]]: assert len(word_le_bytes) == 32, "Expected RLC to contain 32 bytes" return self.bytes_to_int(word_le_bytes[:16]), self.bytes_to_int(word_le_bytes[16:]) - def 
bytes_to_rlc(self, bytes: Sequence[int]) -> RLC: - return RLC(bytes, self.randomness) + def int_to_rlc(self, value: int, n_bytes: int) -> RLC: + return RLC(value, self.randomness, n_bytes) - def bytes_to_int(self, bytes: Sequence[int]) -> int: - assert len(bytes) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + def bytes_to_int(self, value: Sequence[int]) -> int: + assert len(value) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" - return int.from_bytes(bytes, "little") + return int.from_bytes(value, "little") - def int_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: + def range_lookup(self, value: int, range: int): + self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), value, 0, 0]) + + def byte_range_lookup(self, value: int): + self.range_lookup(value, 256) + + def bytes_range_lookup(self, value: int, n_bytes: int) -> Sequence[int]: assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" try: @@ -299,12 +305,6 @@ def int_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: except OverflowError: raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") - def range_lookup(self, input: int, range: int): - self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), input, 0, 0]) - - def byte_range_lookup(self, input: int): - self.range_lookup(input, 256) - def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[int]) -> Array4: return self.tables.fixed_lookup([tag] + inputs) diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index d224afc31..3bd3d44ff 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -14,6 +14,8 @@ def verify_steps( begin_with_first_step: bool = False, end_with_final_step: bool = False, ): + # TODO: Enforce general ExecutionState transition constraint + for idx in range(len(steps) - 1): verify_step( Instruction( diff --git a/tests/evm/test_begin_tx.py 
b/tests/evm/test_begin_tx.py index 934871c7e..4f338d7cd 100644 --- a/tests/evm/test_begin_tx.py +++ b/tests/evm/test_begin_tx.py @@ -1,4 +1,5 @@ import pytest +from itertools import chain from zkevm_specs.evm import ( ExecutionState, @@ -11,51 +12,95 @@ CallContextFieldTag, Block, Transaction, + Account, Bytecode, ) from zkevm_specs.util import rand_fp, rand_address, rand_range, RLC +from zkevm_specs.util.hash import EMPTY_CODE_HASH + +RETURN_BYTECODE = Bytecode().return_(0, 0) +REVERT_BYTECODE = Bytecode().revert(0, 0) + +CALLEE_ADDRESS = 0xFF +CALLEE_WITH_NOTHING = Account(address=CALLEE_ADDRESS) +CALLEE_WITH_RETURN_BYTECODE = Account(address=CALLEE_ADDRESS, code_hash=RETURN_BYTECODE.hash()) +CALLEE_WITH_REVERT_BYTECODE = Account(address=CALLEE_ADDRESS, code_hash=REVERT_BYTECODE.hash()) TESTING_DATA = ( - # Transfer 1 ether, successfully - (Transaction(caller_address=0xFE, callee_address=0xFF, value=int(1e18)), True), - # Transfer 1 ether, tx reverts - (Transaction(caller_address=0xFE, callee_address=0xFF, value=int(1e18)), False), + # Transfer 1 ether to EOA, successfully + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_NOTHING, + True, + ), + # Transfer 1 ether to contract, successfully + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), + # Transfer 1 ether to contract, tx reverts + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_REVERT_BYTECODE, + False, + ), # Transfer random ether, successfully - (Transaction(caller_address=rand_address(), callee_address=rand_address(), value=rand_range(1e20)), True), + ( + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, value=rand_range(1e20)), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), # Transfer nothing with random gas_price, successfully ( - Transaction(caller_address=rand_address(), 
callee_address=rand_address(), gas_price=rand_range(42857142857143)), + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, gas_price=rand_range(42857142857143)), + CALLEE_WITH_RETURN_BYTECODE, True, ), # Transfer random ether, tx reverts - (Transaction(caller_address=rand_address(), callee_address=rand_address(), value=rand_range(1e20)), False), + ( + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, value=rand_range(1e20)), + CALLEE_WITH_REVERT_BYTECODE, + False, + ), # Transfer nothing with random gas_price, tx reverts ( - Transaction(caller_address=rand_address(), callee_address=rand_address(), gas_price=rand_range(42857142857143)), + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, gas_price=rand_range(42857142857143)), + CALLEE_WITH_REVERT_BYTECODE, False, ), # Transfer nothing with some calldata - (Transaction(caller_address=0xFE, callee_address=0xFF, gas=21080, call_data=bytes([1, 2, 3, 4, 0, 0, 0, 0])), True), + ( + Transaction( + caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=21080, call_data=bytes([1, 2, 3, 4, 0, 0, 0, 0]) + ), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), ) -@pytest.mark.parametrize("tx, result", TESTING_DATA) -def test_begin_tx(tx: Transaction, result: bool): +@pytest.mark.parametrize("tx, callee, result", TESTING_DATA) +def test_begin_tx(tx: Transaction, callee: Account, result: bool): randomness = rand_fp() rw_counter_end_of_reversion = 23 caller_balance_prev = int(1e20) - callee_balance_prev = 0 + callee_balance_prev = callee.balance caller_balance = caller_balance_prev - (tx.value + tx.gas * tx.gas_price) callee_balance = callee_balance_prev + tx.value - bytecode = Bytecode() - bytecode_hash = RLC(bytecode.hash(), randomness) + bytecode_hash = RLC(callee.code_hash, randomness) tables = Tables( block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), - 
bytecode_table=set(bytecode.table_assignments(randomness)), + bytecode_table=set( + chain( + RETURN_BYTECODE.table_assignments(randomness), + REVERT_BYTECODE.table_assignments(randomness), + ) + ), rw_table=set( [ (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), @@ -158,7 +203,7 @@ def test_begin_tx(tx: Transaction, result: bool): rw_counter=1, ), StepState( - execution_state=ExecutionState.STOP if result else ExecutionState.REVERT, + execution_state=ExecutionState.EndTx if callee.code_hash == EMPTY_CODE_HASH else ExecutionState.PUSH, rw_counter=17, call_id=1, is_root=True, From f6f27608aa0d6db4159fbde84300fec8bab96093 Mon Sep 17 00:00:00 2001 From: han0110 Date: Thu, 13 Jan 2022 20:09:51 +0800 Subject: [PATCH 4/8] feat: add test of EndTx --- src/zkevm_specs/evm/execution/end_tx.py | 28 +++--- src/zkevm_specs/evm/instruction.py | 19 +++- src/zkevm_specs/evm/main.py | 8 +- src/zkevm_specs/util/param.py | 2 + tests/evm/test_end_tx.py | 113 ++++++++++++++++++++++++ 5 files changed, 151 insertions(+), 19 deletions(-) create mode 100644 tests/evm/test_end_tx.py diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py index af28a63c6..a8b71b459 100644 --- a/src/zkevm_specs/evm/execution/end_tx.py +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -1,22 +1,22 @@ -from ...util import N_BYTES_GAS +from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED from ..execution_state import ExecutionState -from ..instruction import Instruction +from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag, CallContextFieldTag, TxContextFieldTag def end_tx(instruction: Instruction): tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) - # Handle gas refund (refund is capped to gas_used // 2 in EIP 3529) + # Handle gas refund (refund is capped to gas_used // MAX_REFUND_QUOTIENT_OF_GAS_USED in EIP 3529) tx_gas = instruction.tx_context_lookup(tx_id, 
TxContextFieldTag.Gas) gas_used = tx_gas - instruction.curr.gas_left - capped_refund, _ = instruction.constant_divmod(gas_used, 2, N_BYTES_GAS) - accumulated_refund = instruction.tx_refund_read(tx_id) - refund = instruction.min(capped_refund, accumulated_refund, 8) + max_refund, _ = instruction.constant_divmod(gas_used, MAX_REFUND_QUOTIENT_OF_GAS_USED, N_BYTES_GAS) + refund = instruction.tx_refund_read(tx_id) + effective_refund = instruction.min(max_refund, refund, 8) - # Add refund * gas_price back to caller's balance + # Add effective_refund * gas_price back to caller's balance tx_gas_price = instruction.tx_gas_price(tx_id) - value, carry = instruction.mul_word_by_u64(tx_gas_price, refund) + value, carry = instruction.mul_word_by_u64(tx_gas_price, instruction.curr.gas_left + effective_refund) instruction.constrain_zero(carry) tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) instruction.add_balance(tx_caller_address, [value]) @@ -29,9 +29,6 @@ def end_tx(instruction: Instruction): coinbase = instruction.block_context_lookup(BlockContextFieldTag.Coinbase) instruction.add_balance(coinbase, [reward]) - # Do step state transition for rw_counter - instruction.constrain_equal(instruction.next.rw_counter, instruction.curr.rw_counter + 7) - # Go to next transaction if instruction.next.execution_state == ExecutionState.BeginTx: # Check next tx_id is increased by 1 @@ -40,4 +37,11 @@ def end_tx(instruction: Instruction): tx_id + 1, ) - # Or to ExecutionState.EndBlock + # Do step state transition for rw_counter + instruction.constrain_step_state_transition(rw_counter=Transition.delta(5)) + # Go to end of block + elif instruction.next.execution_state == ExecutionState.EndBlock: + # Do step state transition for rw_counter + instruction.constrain_step_state_transition(rw_counter=Transition.delta(4)) + else: + raise ValueError("Unreacheable") diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 
1857b18ca..02527ff10 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -123,9 +123,8 @@ def constrain_step_state_transition(self, **kwargs: Transition): kwargs.keys() ), f"Invalid keys {list(set(kwargs.keys()).difference(keys))} for step state transition" - for key in keys: + for key, transition in kwargs.items(): curr, next = getattr(self.curr, key), getattr(self.next, key) - transition = kwargs.get(key, Transition.same()) if transition.kind == TransitionKind.Same: assert next == curr, ConstraintUnsatFailure(f"State {key} should be same as {curr}, but got {next}") elif transition.kind == TransitionKind.Delta: @@ -173,6 +172,7 @@ def step_state_transition_in_same_context( program_counter: Transition = Transition.same(), stack_pointer: Transition = Transition.same(), memory_size: Transition = Transition.same(), + state_write_counter: Transition = Transition.same(), dynamic_gas_cost: int = 0, ): self.responsible_opcode_lookup(opcode) @@ -184,8 +184,17 @@ def step_state_transition_in_same_context( rw_counter=rw_counter, program_counter=program_counter, stack_pointer=stack_pointer, - memory_size=memory_size, gas_left=Transition.delta(-gas_cost), + memory_size=memory_size, + state_write_counter=state_write_counter, + # Always stay same + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_source=Transition.same(), + last_callee_id=Transition.same(), + last_callee_return_data_offset=Transition.same(), + last_callee_return_data_length=Transition.same(), ) def sum(self, values: Sequence[int]) -> int: @@ -423,6 +432,10 @@ def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = No return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address])[5] + def tx_refund_read(self, tx_id) -> int: + row = self.rw_lookup(RW.Read, RWTableTag.TxRefund, [tx_id]) + return row[4] + def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> int: row = 
self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) return row[5] diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index 3bd3d44ff..f680a243e 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -12,19 +12,19 @@ def verify_steps( tables: Tables, steps: Sequence[StepState], begin_with_first_step: bool = False, - end_with_final_step: bool = False, + end_with_last_step: bool = False, ): # TODO: Enforce general ExecutionState transition constraint - for idx in range(len(steps) - 1): + for idx in range(len(steps) if end_with_last_step else len(steps) - 1): verify_step( Instruction( randomness=randomness, tables=tables, curr=steps[idx], - next=steps[idx + 1], + next=steps[idx + 1] if idx + 1 < len(steps) else None, is_first_step=begin_with_first_step and idx == 0, - is_last_step=end_with_final_step and idx == len(steps) - 2, + is_last_step=idx + 1 == len(steps), ), ) diff --git a/src/zkevm_specs/util/param.py b/src/zkevm_specs/util/param.py index 87ff0e194..41e2a6a39 100644 --- a/src/zkevm_specs/util/param.py +++ b/src/zkevm_specs/util/param.py @@ -32,6 +32,8 @@ # Gas stipend given if call with non-zero value GAS_STIPEND_CALL_WITH_VALUE = 2300 +# Quotient for max refund of gas used +MAX_REFUND_QUOTIENT_OF_GAS_USED = 5 # Denominator of quadratic part of memory expansion gas cost MEMORY_EXPANSION_QUAD_DENOMINATOR = 512 # Coefficient of linear part of memory expansion gas cost diff --git a/tests/evm/test_end_tx.py b/tests/evm/test_end_tx.py new file mode 100644 index 000000000..b69725937 --- /dev/null +++ b/tests/evm/test_end_tx.py @@ -0,0 +1,113 @@ +import pytest + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + RWTableTag, + RW, + AccountFieldTag, + CallContextFieldTag, + Block, + Transaction, +) +from zkevm_specs.util import rand_fp, RLC, EMPTY_CODE_HASH, MAX_REFUND_QUOTIENT_OF_GAS_USED + +CALLEE_ADDRESS = 0xFF + +TESTING_DATA = ( + # Tx with non-capped 
refund + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=27000, gas_price=int(2e9)), + 994, + 4800, + False, + ), + # Tx with capped refund + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=65000, gas_price=int(2e9)), + 3952, + 38400, + False, + ), + # Last tx + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=21000, gas_price=int(2e9)), + 0, + 0, + True, + ), +) + + +@pytest.mark.parametrize("tx, gas_left, refund, is_last_tx", TESTING_DATA) +def test_end_tx(tx: Transaction, gas_left: int, refund: int, is_last_tx: bool): + randomness = rand_fp() + + block = Block() + effective_refund = min(refund, (tx.gas - gas_left) // MAX_REFUND_QUOTIENT_OF_GAS_USED) + caller_balance_prev = int(1e18) - (tx.value + tx.gas * tx.gas_price) + caller_balance = caller_balance_prev + (gas_left + effective_refund) * tx.gas_price + coinbase_balance_prev = 0 + coinbase_balance = coinbase_balance_prev + (tx.gas - gas_left) * (tx.gas_price - block.base_fee) + + tables = Tables( + block_table=set(block.table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(), + rw_table=set( + [ + (17, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), + (18, RW.Read, RWTableTag.TxRefund, tx.id, refund, refund, 0, 0), + ( + 19, + RW.Write, + RWTableTag.Account, + tx.caller_address, + AccountFieldTag.Balance, + RLC(caller_balance, randomness), + RLC(caller_balance_prev, randomness), + 0, + ), + ( + 20, + RW.Write, + RWTableTag.Account, + block.coinbase, + AccountFieldTag.Balance, + RLC(coinbase_balance, randomness), + RLC(coinbase_balance_prev, randomness), + 0, + ), + ] + + ( + [] + if is_last_tx + else [(21, RW.Read, RWTableTag.CallContext, 22, CallContextFieldTag.TxId, tx.id + 1, 0, 0)] + ) + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.EndTx, + rw_counter=17, + call_id=1, + is_root=True, 
+ is_create=False, + code_source=RLC(EMPTY_CODE_HASH, randomness), + program_counter=0, + stack_pointer=1024, + gas_left=gas_left, + state_write_counter=2, + ), + StepState( + execution_state=ExecutionState.EndBlock if is_last_tx else ExecutionState.BeginTx, + rw_counter=22 - is_last_tx, + ), + ], + ) From 0df591d28552632eda59aaee4f9b963a33718aeb Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 14 Jan 2022 00:34:22 +0800 Subject: [PATCH 5/8] feat: add test of EndBlock --- tests/evm/test_end_block.py | 57 +++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/evm/test_end_block.py diff --git a/tests/evm/test_end_block.py b/tests/evm/test_end_block.py new file mode 100644 index 000000000..21873f85b --- /dev/null +++ b/tests/evm/test_end_block.py @@ -0,0 +1,57 @@ +import pytest +from itertools import chain + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + RWTableTag, + RW, + CallContextFieldTag, + Block, + Transaction, +) +from zkevm_specs.util import rand_fp + +TESTING_DATA = (False, True) + + +@pytest.mark.parametrize("is_last_step", TESTING_DATA) +def test_end_block(is_last_step: bool): + randomness = rand_fp() + + tx = Transaction() + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(), + rw_table=set( + chain( + # dummy read/write for counting + [(i, *7 * [0]) for i in range(22)], + [(22, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0)] + if is_last_step + else [], + ) + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.EndBlock, + rw_counter=22, + call_id=1, + ), + StepState( + execution_state=ExecutionState.EndBlock, + rw_counter=22, + call_id=1, + ), + ], + end_with_last_step=is_last_step, + ) From 7f85c177272d26d7a26d6103fadfc6e1602bf7c7 Mon Sep 17 00:00:00 2001 From: han0110 Date: 
Sat, 15 Jan 2022 00:15:26 +0800 Subject: [PATCH 6/8] fix: fix SignByte to have correct sign byte --- src/zkevm_specs/evm/table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 4c151c682..d19b18f5c 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -54,7 +54,7 @@ def table_assignments(self) -> Sequence[Array4]: elif self == FixedTableTag.Range1024: return [(self, i, 0, 0) for i in range(1024)] elif self == FixedTableTag.SignByte: - return [(self, i, (i & 1) * 0xFF, 0) for i in range(256)] + return [(self, i, (i >> 7) * 0xFF, 0) for i in range(256)] elif self == FixedTableTag.BitwiseAnd: return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] elif self == FixedTableTag.BitwiseOr: From dd061107981ee27ae6e2ab81f78d1b2bb075dc9e Mon Sep 17 00:00:00 2001 From: han0110 Date: Sat, 15 Jan 2022 22:46:06 +0800 Subject: [PATCH 7/8] refactor: rename fields of Block --- src/zkevm_specs/evm/table.py | 10 +++++----- src/zkevm_specs/evm/typing.py | 32 +++++++++++++++++++------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index d19b18f5c..116605a2f 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -107,11 +107,11 @@ class BlockContextFieldTag(IntEnum): Coinbase = auto() GasLimit = auto() - BlockNumber = auto() - Time = auto() + Number = auto() + Timestamp = auto() Difficulty = auto() BaseFee = auto() - BlockHash = auto() + HistoryHash = auto() class TxContextFieldTag(IntEnum): @@ -262,14 +262,14 @@ class Tables: # Each row in BlockTable contains: # - tag - # - block_number_or_zero (meaningful only for BlockHash, will be zero for other tags) + # - block_number_or_zero (meaningful only for HistoryHash, will be zero for other tags) # - value block_table: Set[Array3] # Each row in TxTable contains: # - tx_id # - tag - # - 
# - index_or_zero (meaningful only for CallData, will be zero for other tags) + # - call_data_index_or_zero (meaningful only for CallData, will be zero for other tags) # - value tx_table: Set[Array4] diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index f78a592d8..8e02738f7 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -22,32 +22,38 @@ class Block: coinbase: U160 + + # Gas needs a lot of arithmetic operations or comparisons in EVM circuit, so we + # assume gas limit in the near future will not exceed U64, to reduce the + # implementation complexity. gas_limit: U64 - block_number: U256 - time: U64 + + # For other fields, we follow the size defined in yellow paper for now. + number: U256 + timestamp: U256 difficulty: U256 base_fee: U256 - # history_hashes contains most recent 256 block hashes in history, where - # the lastest one is at history_hashes[-1]. + # It contains most recent 256 block hashes in history, where the latest + # one is at history_hashes[-1]. 
history_hashes: Sequence[U256] def __init__( self, coinbase: U160 = 0x10, gas_limit: U64 = int(15e6), - block_number: U256 = 0, - time: U64 = 0, + number: U256 = 0, + timestamp: U256 = 0, difficulty: U256 = 0, base_fee: U256 = int(1e9), history_hashes: Sequence[U256] = [], ) -> None: - assert len(history_hashes) <= min(256, block_number) + assert len(history_hashes) <= min(256, number) self.coinbase = coinbase self.gas_limit = gas_limit - self.block_number = block_number - self.time = time + self.number = number + self.timestamp = timestamp self.difficulty = difficulty self.base_fee = base_fee self.history_hashes = history_hashes @@ -56,13 +62,13 @@ def table_assignments(self, randomness: int) -> Sequence[Array3]: return [ (BlockContextFieldTag.Coinbase, 0, self.coinbase), (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.BlockNumber, 0, RLC(self.block_number, randomness)), - (BlockContextFieldTag.Time, 0, RLC(self.time, randomness)), + (BlockContextFieldTag.Number, 0, RLC(self.number, randomness)), + (BlockContextFieldTag.Timestamp, 0, RLC(self.timestamp, randomness)), (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), ] + [ - (BlockContextFieldTag.BlockHash, self.block_number - idx - 1, RLC(block_hash, randomness)) - for idx, block_hash in enumerate(reversed(self.history_hashes)) + (BlockContextFieldTag.HistoryHash, self.number - idx - 1, RLC(history_hash, randomness)) + for idx, history_hash in enumerate(reversed(self.history_hashes)) ] From e69963c574242ceb10087be539f6ff04d639e1bf Mon Sep 17 00:00:00 2001 From: han0110 Date: Sun, 16 Jan 2022 12:57:22 +0800 Subject: [PATCH 8/8] refactor: use code and storage in Account instead of hash --- src/zkevm_specs/evm/typing.py | 25 ++++++++++++++++--------- tests/evm/test_begin_tx.py | 16 +++++----------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/zkevm_specs/evm/typing.py 
b/src/zkevm_specs/evm/typing.py index 8e02738f7..a07ed8241 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Iterator, Optional, Sequence +from typing import Any, Dict, Iterator, NewType, Optional, Sequence from functools import reduce from itertools import chain @@ -13,8 +13,6 @@ keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, - EMPTY_HASH, - EMPTY_TRIE_HASH, ) from .table import BlockContextFieldTag, TxContextFieldTag from .opcode import get_push_size, Opcode @@ -213,23 +211,32 @@ def __next__(self): return BytecodeIterator(RLC(self.hash(), randomness), self.code) +Storage = NewType("Storage", Dict[U256, U256]) + + class Account: address: U160 nonce: U256 balance: U256 - code_hash: U256 - storage_trie_hash: U256 + code: Bytecode + storage: Storage def __init__( self, address: U160 = 0, nonce: U256 = 0, balance: U256 = 0, - code_hash: U256 = EMPTY_HASH, - storage_trie_hash: U256 = EMPTY_TRIE_HASH, + code: Optional[Bytecode] = None, + storage: Optional[Storage] = None, ) -> None: self.address = address self.nonce = nonce self.balance = balance - self.code_hash = code_hash - self.storage_trie_hash = storage_trie_hash + self.code = Bytecode() if code is None else code + self.storage = dict() if storage is None else storage + + def code_hash(self) -> U256: + return self.code.hash() + + def storage_trie_hash(self) -> U256: + raise NotImplementedError("Trie has not been implemented") diff --git a/tests/evm/test_begin_tx.py b/tests/evm/test_begin_tx.py index 4f338d7cd..ac199750a 100644 --- a/tests/evm/test_begin_tx.py +++ b/tests/evm/test_begin_tx.py @@ -1,5 +1,4 @@ import pytest -from itertools import chain from zkevm_specs.evm import ( ExecutionState, @@ -23,8 +22,8 @@ CALLEE_ADDRESS = 0xFF CALLEE_WITH_NOTHING = Account(address=CALLEE_ADDRESS) -CALLEE_WITH_RETURN_BYTECODE = Account(address=CALLEE_ADDRESS, 
code_hash=RETURN_BYTECODE.hash()) -CALLEE_WITH_REVERT_BYTECODE = Account(address=CALLEE_ADDRESS, code_hash=REVERT_BYTECODE.hash()) +CALLEE_WITH_RETURN_BYTECODE = Account(address=CALLEE_ADDRESS, code=RETURN_BYTECODE) +CALLEE_WITH_REVERT_BYTECODE = Account(address=CALLEE_ADDRESS, code=REVERT_BYTECODE) TESTING_DATA = ( # Transfer 1 ether to EOA, successfully @@ -90,17 +89,12 @@ def test_begin_tx(tx: Transaction, callee: Account, result: bool): caller_balance = caller_balance_prev - (tx.value + tx.gas * tx.gas_price) callee_balance = callee_balance_prev + tx.value - bytecode_hash = RLC(callee.code_hash, randomness) + bytecode_hash = RLC(callee.code_hash(), randomness) tables = Tables( block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), - bytecode_table=set( - chain( - RETURN_BYTECODE.table_assignments(randomness), - REVERT_BYTECODE.table_assignments(randomness), - ) - ), + bytecode_table=set(callee.code.table_assignments(randomness)), rw_table=set( [ (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), @@ -203,7 +197,7 @@ def test_begin_tx(tx: Transaction, callee: Account, result: bool): rw_counter=1, ), StepState( - execution_state=ExecutionState.EndTx if callee.code_hash == EMPTY_CODE_HASH else ExecutionState.PUSH, + execution_state=ExecutionState.EndTx if callee.code_hash() == EMPTY_CODE_HASH else ExecutionState.PUSH, rw_counter=17, call_id=1, is_root=True,