diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 9d5f9537e..4d2b085a7 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -1,6 +1,6 @@ from typing import Sequence, Union, Tuple, Set from collections import namedtuple -from .util import keccak256, fp_add, fp_mul, RLCStore +from .util import keccak256, fp_add, fp_mul, RLC from .evm.opcode import get_push_size from .encoding import U8, U256, is_circuit_code @@ -87,7 +87,7 @@ def check_bytecode_row( # Populate the circuit matrix -def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_store: RLCStore): +def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], randomness: int): # All rows are usable in this emulation last_row_offset = 2 ** k - 1 @@ -103,7 +103,7 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_s push_data_left = byte_push_size if is_code else push_data_left - 1 # Add the byte to the accumulator - hash_rlc = fp_add(fp_mul(hash_rlc, rlc_store.randomness), row[2]) + hash_rlc = fp_add(fp_mul(hash_rlc, randomness), row[2]) # Set the data for this row rows.append( @@ -162,10 +162,10 @@ def assign_push_table(): # Generate keccak table -def assign_keccak_table(bytecodes: Sequence[bytes], rlc_store: RLCStore): +def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int): keccak_table = [] for bytecode in bytecodes: - hash = rlc_store.to_rlc(keccak256(bytecode), 32) - rlc = rlc_store.to_rlc(list(reversed(bytecode))) + hash = RLC(bytes(reversed(keccak256(bytecode))), randomness) + rlc = RLC(bytes(reversed(bytecode)), randomness, len(bytecode)) keccak_table.append((rlc, len(bytecode), hash)) return keccak_table diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 0306dd1dd..429364adc 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -1,11 +1,28 @@ +from typing import Callable, Dict + +from ..execution_state import ExecutionState + from .begin_tx import * +from .end_tx import * +from .end_block import * # Opcode's successful cases from .add import * -from .push import * from .jump import * from .jumpi import * +from .push import * from .block_coinbase import * from .caller import * -# Error cases + +EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = { + ExecutionState.BeginTx: begin_tx, + ExecutionState.EndTx: end_tx, + ExecutionState.EndBlock: end_block, + ExecutionState.ADD: add, + ExecutionState.CALLER: caller, + ExecutionState.COINBASE: coinbase, + ExecutionState.JUMP: jump, + ExecutionState.JUMPI: jumpi, + ExecutionState.PUSH: push, +} diff --git a/src/zkevm_specs/evm/execution/add.py b/src/zkevm_specs/evm/execution/add.py index 9f040acef..8bd6f2e54 100644 --- a/src/zkevm_specs/evm/execution/add.py +++ b/src/zkevm_specs/evm/execution/add.py @@ -16,7 +16,7 @@ def add(instruction: Instruction): instruction.select(is_sub, a, c), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(3), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index 48dc2b29c..907b4e9e6 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,41 +1,45 @@ +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH +from ..execution_state import ExecutionState from ..instruction import Instruction, Transition -from ..table import CallContextFieldTag, TxContextFieldTag, RW, AccountFieldTag from ..precompiled import PrecompiledAddress +from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag -def begin_tx(instruction: Instruction, is_first_step: bool = False): - instruction.constrain_equal(instruction.curr.call_id, instruction.curr.rw_counter) +def begin_tx(instruction: Instruction): + call_id = instruction.curr.rw_counter - tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) - rw_counter_end_of_reversion = instruction.call_context_lookup(CallContextFieldTag.RWCounterEndOfReversion) - is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent) + tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=call_id) + rw_counter_end_of_reversion = instruction.call_context_lookup( + CallContextFieldTag.RwCounterEndOfReversion, call_id=call_id + ) + is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent, call_id=call_id) - if is_first_step: + if instruction.is_first_step: instruction.constrain_equal(instruction.curr.rw_counter, 1) instruction.constrain_equal(tx_id, 1) - tx_caller_address = instruction.tx_lookup(tx_id, TxContextFieldTag.CallerAddress) - tx_callee_address = instruction.tx_lookup(tx_id, TxContextFieldTag.CalleeAddress) - tx_is_create = instruction.tx_lookup(tx_id, TxContextFieldTag.IsCreate) - tx_value = instruction.tx_lookup(tx_id, TxContextFieldTag.Value) - tx_call_data_length = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataLength) + tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) + tx_callee_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CalleeAddress) + tx_is_create = instruction.tx_context_lookup(tx_id, TxContextFieldTag.IsCreate) + tx_value = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Value) + tx_call_data_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength) # Verify nonce - tx_nonce = instruction.tx_lookup(tx_id, TxContextFieldTag.Nonce) + tx_nonce = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Nonce) nonce, nonce_prev = instruction.account_write(tx_caller_address, AccountFieldTag.Nonce) instruction.constrain_equal(tx_nonce, nonce_prev) instruction.constrain_equal(nonce, nonce_prev + 1) # TODO: Implement EIP 1559 (currently it supports legacy transaction format) # Calculate gas fee - tx_gas = instruction.tx_lookup(tx_id, TxContextFieldTag.Gas) - tx_gas_price = instruction.tx_lookup(tx_id, TxContextFieldTag.GasPrice) + tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas) + tx_gas_price = instruction.tx_gas_price(tx_id) gas_fee, carry = instruction.mul_word_by_u64(tx_gas_price, tx_gas) instruction.constrain_zero(carry) # TODO: Handle gas cost of tx level access list (EIP 2930) - tx_call_data_gas_cost = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataGasCost) - gas_left = tx_gas - (53000 if tx_is_create else 21000) - tx_call_data_gas_cost + tx_call_data_gas_cost = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataGasCost) + gas_left = tx_gas - (GAS_COST_CREATION_TX if tx_is_create else GAS_COST_TX) - tx_call_data_gas_cost instruction.constrain_gas_left_not_underflow(gas_left) # Prepare access list of caller and callee @@ -47,27 +51,27 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False): tx_caller_address, tx_callee_address, tx_value, - gas_fee=gas_fee, - is_persistent=is_persistent, - rw_counter_end_of_reversion=rw_counter_end_of_reversion, + gas_fee, + is_persistent, + rw_counter_end_of_reversion, ) if tx_is_create: - # TODO: Verify receiver address - # TODO: Set opcode_source to tx_id + # TODO: Verify created address + # TODO: Decide what code_source should be (tx_id or hash of creation code) raise NotImplementedError elif tx_callee_address in list(PrecompiledAddress): # TODO: Handle precompile raise NotImplementedError else: - code_hash, _ = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash) + code_hash = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash) # Setup next call's context # Note that: - # - CallerCallId, ReturnDataOffset, ReturnDataLength, Result - # should never be used in root call, so unnecessary to check - # - TxId is propagated from previous step or constraint to 1 if is_first_step - # - IsPersistent will be verified in the end of tx + # - CallerId, ReturnDataOffset, ReturnDataLength + # should never be used in root call, so unnecessary to be checked + # - TxId is checked from previous step or constraint to 1 if is_first_step + # - IsSuccess, IsPersistent will be verified in the end of tx for (tag, value) in [ (CallContextFieldTag.Depth, 1), (CallContextFieldTag.CallerAddress, tx_caller_address), @@ -77,14 +81,21 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False): (CallContextFieldTag.Value, tx_value), (CallContextFieldTag.IsStatic, False), ]: - instruction.constrain_equal(instruction.call_context_lookup(tag), value) + instruction.constrain_equal(instruction.call_context_lookup(tag, call_id=call_id), value) - instruction.constrain_new_context_state_transition( + instruction.step_state_transition_to_new_context( rw_counter=Transition.delta(16), - call_id=Transition.persistent(), + call_id=Transition.to(call_id), is_root=Transition.to(True), is_create=Transition.to(False), - opcode_source=Transition.to(code_hash), + code_source=Transition.to(code_hash), gas_left=Transition.to(gas_left), state_write_counter=Transition.to(2), ) + + # Constrain either: + # - is_empty_code and is_to_end_tx + # - (not is_empty_code) and (not is_to_end_tx) + is_empty_code = instruction.is_equal(code_hash, instruction.int_to_rlc(EMPTY_CODE_HASH, 32)) + is_to_end_tx = instruction.is_equal(instruction.next.execution_state, ExecutionState.EndTx) + instruction.constrain_equal(is_empty_code + is_to_end_tx, 2 * is_empty_code * is_to_end_tx) diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 3b47cdc28..1cffca444 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -11,10 +11,13 @@ def coinbase(instruction: Instruction): # check block table for coinbase address instruction.constrain_equal( address, - instruction.bytes_to_rlc(instruction.int_to_bytes(instruction.block_lookup(BlockContextFieldTag.Coinbase), 20)), + instruction.int_to_rlc( + instruction.block_context_lookup(BlockContextFieldTag.Coinbase), + 20, + ), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 1750fb905..d7cf9f209 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -11,15 +11,13 @@ def caller(instruction: Instruction): # check [rw_table, call_context] table for caller address instruction.constrain_equal( address, - instruction.bytes_to_rlc( - instruction.int_to_bytes( - instruction.call_context_lookup(CallContextFieldTag.CallerAddress), - 20, - ) + instruction.int_to_rlc( + instruction.call_context_lookup(CallContextFieldTag.CallerAddress), + 20, ), ) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(2), program_counter=Transition.delta(1), diff --git a/src/zkevm_specs/evm/execution/end_block.py b/src/zkevm_specs/evm/execution/end_block.py new file mode 100644 index 000000000..4ba98a1e9 --- /dev/null +++ b/src/zkevm_specs/evm/execution/end_block.py @@ -0,0 +1,30 @@ +from ..instruction import Instruction, Transition +from ..table import CallContextFieldTag + + +# TODO: Introduce constrain_instance to constrain the equality between witness +# and public input, for total_tx and total_rw + + +def end_block(instruction: Instruction): + if instruction.is_last_step: + # Verify final step has tx_id identical to the tx amount in tx_table. + total_tx = instruction.call_context_lookup(CallContextFieldTag.TxId) + instruction.constrain_equal( + total_tx, + max([tx_id for tx_id, *_ in instruction.tables.tx_table]), + ) + + # Verify rw_counter counts to identical rw amount in rw_table to ensure + # there is no malicious insertion. + total_rw = instruction.curr.rw_counter + 1 # extra 1 from the tx_id lookup + instruction.constrain_equal( + total_rw, + len(instruction.tables.rw_table), + ) + else: + # Propagate rw_counter and call_id all the way down + instruction.constrain_step_state_transition( + rw_counter=Transition.same(), + call_id=Transition.same(), + ) diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py new file mode 100644 index 000000000..a8b71b459 --- /dev/null +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -0,0 +1,47 @@ +from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED +from ..execution_state import ExecutionState +from ..instruction import Instruction, Transition +from ..table import BlockContextFieldTag, CallContextFieldTag, TxContextFieldTag + + +def end_tx(instruction: Instruction): + tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) + + # Handle gas refund (refund is capped to gas_used // MAX_REFUND_QUOTIENT_OF_GAS_USED in EIP 3529) + tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas) + gas_used = tx_gas - instruction.curr.gas_left + max_refund, _ = instruction.constant_divmod(gas_used, MAX_REFUND_QUOTIENT_OF_GAS_USED, N_BYTES_GAS) + refund = instruction.tx_refund_read(tx_id) + effective_refund = instruction.min(max_refund, refund, 8) + + # Add effective_refund * gas_price back to caller's balance + tx_gas_price = instruction.tx_gas_price(tx_id) + value, carry = instruction.mul_word_by_u64(tx_gas_price, instruction.curr.gas_left + effective_refund) + instruction.constrain_zero(carry) + tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) + instruction.add_balance(tx_caller_address, [value]) + + # Add gas_used * effective_tip to coinbase's balance + base_fee = instruction.block_context_lookup(BlockContextFieldTag.BaseFee) + effective_tip, _ = instruction.sub_word(tx_gas_price, base_fee) + reward, carry = instruction.mul_word_by_u64(effective_tip, gas_used) + instruction.constrain_zero(carry) + coinbase = instruction.block_context_lookup(BlockContextFieldTag.Coinbase) + instruction.add_balance(coinbase, [reward]) + + # Go to next transaction + if instruction.next.execution_state == ExecutionState.BeginTx: + # Check next tx_id is increased by 1 + instruction.constrain_equal( + instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=instruction.next.rw_counter), + tx_id + 1, + ) + + # Do step state transition for rw_counter + instruction.constrain_step_state_transition(rw_counter=Transition.delta(5)) + # Go to end of block + elif instruction.next.execution_state == ExecutionState.EndBlock: + # Do step state transition for rw_counter + instruction.constrain_step_state_transition(rw_counter=Transition.delta(4)) + else: + raise ValueError("Unreacheable") diff --git a/src/zkevm_specs/evm/execution/jump.py b/src/zkevm_specs/evm/execution/jump.py index 454feb6c9..c21d668a3 100644 --- a/src/zkevm_specs/evm/execution/jump.py +++ b/src/zkevm_specs/evm/execution/jump.py @@ -1,3 +1,4 @@ +from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -11,13 +12,13 @@ def jump(instruction: Instruction): dest = instruction.stack_pop() # Get `dest` raw value in max 8 bytes - dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8)) + dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER) # Verify `dest` is code within byte code table # assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True) instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True)) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.to(dest_value), diff --git a/src/zkevm_specs/evm/execution/jumpi.py b/src/zkevm_specs/evm/execution/jumpi.py index 3cba2b651..6ff768c5c 100644 --- a/src/zkevm_specs/evm/execution/jumpi.py +++ b/src/zkevm_specs/evm/execution/jumpi.py @@ -1,3 +1,4 @@ +from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -16,12 +17,12 @@ def jumpi(instruction: Instruction): pc_diff = 1 else: # Get `dest` raw value in max 8 bytes - dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8)) + dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER) pc_diff = dest_value - instruction.curr.program_counter # assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True) instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True)) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(2), program_counter=Transition.delta(pc_diff), diff --git a/src/zkevm_specs/evm/execution/push.py b/src/zkevm_specs/evm/execution/push.py index de41bbb8e..3f01f6265 100644 --- a/src/zkevm_specs/evm/execution/push.py +++ b/src/zkevm_specs/evm/execution/push.py @@ -8,17 +8,17 @@ def push(instruction: Instruction): num_additional_pushed = num_pushed - 1 value = instruction.stack_push() - value_bytes = instruction.rlc_to_bytes(value, 32) + value_le_bytes = instruction.rlc_to_le_bytes(value) selectors = instruction.continuous_selectors(num_additional_pushed, 31) for idx in range(32): index = instruction.curr.program_counter + num_pushed - idx if idx == 0 or selectors[idx - 1]: - instruction.constrain_equal(value_bytes[idx], instruction.opcode_lookup_at(index, False)) + instruction.constrain_equal(value_le_bytes[idx], instruction.opcode_lookup_at(index, False)) else: - instruction.constrain_zero(value_bytes[idx]) + instruction.constrain_zero(value_le_bytes[idx]) - instruction.constrain_same_context_state_transition( + instruction.step_state_transition_in_same_context( opcode, rw_counter=Transition.delta(1), program_counter=Transition.delta(1 + num_pushed), diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index e51506235..2069a3398 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -10,6 +10,8 @@ class ExecutionState(IntEnum): """ BeginTx = auto() + EndTx = auto() + EndBlock = auto() # Opcode's successful cases STOP = auto() diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index c8bf630e5..02527ff10 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -2,7 +2,17 @@ from enum import IntEnum, auto from typing import Optional, Sequence, Tuple, Union -from ..util import Array4, Array8, linear_combine, RLCStore, MAX_N_BYTES, N_BYTES_GAS +from ..util import ( + Array4, + Array8, + RLC, + MAX_N_BYTES, + N_BYTES_MEMORY_ADDRESS, + N_BYTES_MEMORY_SIZE, + N_BYTES_GAS, + MEMORY_EXPANSION_QUAD_DENOMINATOR, + MEMORY_EXPANSION_LINEAR_COEFF, +) from .opcode import Opcode from .step import StepState from .table import ( @@ -23,7 +33,7 @@ def __init__(self, message: str) -> None: class TransitionKind(IntEnum): - Persistent = auto() + Same = auto() Delta = auto() To = auto() @@ -36,8 +46,8 @@ def __init__(self, kind: TransitionKind, value: Optional[int] = None) -> None: self.kind = kind self.value = value - def persistent() -> Transition: - return Transition(TransitionKind.Persistent) + def same() -> Transition: + return Transition(TransitionKind.Same) def delta(delta: int): return Transition(TransitionKind.Delta, delta) @@ -47,85 +57,103 @@ def to(to: int): class Instruction: - rlc_store: RLCStore + randomness: int tables: Tables curr: StepState next: StepState + # meta information + is_first_step: bool + is_last_step: bool + # helper numbers - cell_offset: int = 0 rw_counter_offset: int = 0 program_counter_offset: int = 0 stack_pointer_offset: int = 0 state_write_counter_offset: int = 0 - def __init__(self, rlc_store: RLCStore, tables: Tables, curr: StepState, next: StepState) -> None: - self.rlc_store = rlc_store + def __init__( + self, + randomness: int, + tables: Tables, + curr: StepState, + next: StepState, + is_first_step: bool, + is_last_step: bool, + ) -> None: + self.randomness = randomness self.tables = tables self.curr = curr self.next = next + self.is_first_step = is_first_step + self.is_last_step = is_last_step def constrain_zero(self, value: int): - assert value == 0 + assert value == 0, ConstraintUnsatFailure(f"Expected value to be 0, but got {value}") def constrain_equal(self, lhs: int, rhs: int): - self.constrain_zero(lhs - rhs) + assert lhs == rhs, ConstraintUnsatFailure(f"Expected values to be equal, but got {lhs} and {rhs}") def constrain_bool(self, value: int): - assert value in [0, 1] + assert value in [0, 1], ConstraintUnsatFailure(f"Expected value to be a bool, but got {value}") def constrain_gas_left_not_underflow(self, gas_left: int): - self.int_to_bytes(gas_left, N_BYTES_GAS) - - def constrain_state_transition(self, **kwargs: Transition): - for key in [ - "rw_counter", - "call_id", - "is_root", - "is_create", - "opcode_source", - "program_counter", - "stack_pointer", - "gas_left", - "memory_size", - "state_write_counter", - "last_callee_id", - "last_callee_return_data_offset", - "last_callee_return_data_length", - ]: + self.bytes_range_lookup(gas_left, N_BYTES_GAS) + + def constrain_step_state_transition(self, **kwargs: Transition): + keys = set( + [ + "rw_counter", + "call_id", + "is_root", + "is_create", + "code_source", + "program_counter", + "stack_pointer", + "gas_left", + "memory_size", + "state_write_counter", + "last_callee_id", + "last_callee_return_data_offset", + "last_callee_return_data_length", + ] + ) + + assert keys.issuperset( + kwargs.keys() + ), f"Invalid keys {list(set(kwargs.keys()).difference(keys))} for step state transition" + + for key, transition in kwargs.items(): curr, next = getattr(self.curr, key), getattr(self.next, key) - transition = kwargs.get(key, Transition.persistent()) - if transition.kind == TransitionKind.Persistent: - assert next == curr, ConstraintUnsatFailure( - f"state {key} should be persistent as {curr}, but got {next}" - ) + if transition.kind == TransitionKind.Same: + assert next == curr, ConstraintUnsatFailure(f"State {key} should be same as {curr}, but got {next}") elif transition.kind == TransitionKind.Delta: assert next == curr + transition.value, ConstraintUnsatFailure( - f"state {key} should transit to ${curr + transition.value}, but got {next}" + f"State {key} should transit to {curr + transition.value}, but got {next}" ) elif transition.kind == TransitionKind.To: assert next == transition.value, ConstraintUnsatFailure( - f"state {key} should transit to ${transition.value}, but got {next}" + f"State {key} should transit to {transition.value}, but got {next}" ) else: raise ValueError("unreacheable") - def constrain_new_context_state_transition( + def step_state_transition_to_new_context( self, rw_counter: Transition, call_id: Transition, is_root: Transition, is_create: Transition, - opcode_source: Transition, + code_source: Transition, gas_left: Transition, state_write_counter: Transition, ): - self.constrain_state_transition( + self.constrain_step_state_transition( rw_counter=rw_counter, call_id=call_id, is_root=is_root, is_create=is_create, - opcode_source=opcode_source, + code_source=code_source, gas_left=gas_left, state_write_counter=state_write_counter, # Initailization unconditionally @@ -137,30 +165,51 @@ def constrain_new_context_state_transition( last_callee_return_data_length=Transition.to(0), ) - def constrain_same_context_state_transition( + def step_state_transition_in_same_context( self, opcode: int, - rw_counter: Transition = Transition.persistent(), - program_counter: Transition = Transition.persistent(), - stack_pointer: Transition = Transition.persistent(), - memory_size: Transition = Transition.persistent(), + rw_counter: Transition = Transition.same(), + program_counter: Transition = Transition.same(), + stack_pointer: Transition = Transition.same(), + memory_size: Transition = Transition.same(), + state_write_counter: Transition = Transition.same(), dynamic_gas_cost: int = 0, ): - gas_cost = Opcode(opcode).constant_gas_cost() + dynamic_gas_cost + self.responsible_opcode_lookup(opcode) + gas_cost = Opcode(opcode).constant_gas_cost() + dynamic_gas_cost self.constrain_gas_left_not_underflow(self.curr.gas_left - gas_cost) - self.constrain_state_transition( + + self.constrain_step_state_transition( rw_counter=rw_counter, program_counter=program_counter, stack_pointer=stack_pointer, - memory_size=memory_size, gas_left=Transition.delta(-gas_cost), + memory_size=memory_size, + state_write_counter=state_write_counter, + # Always stay same + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_source=Transition.same(), + last_callee_id=Transition.same(), + last_callee_return_data_offset=Transition.same(), + last_callee_return_data_length=Transition.same(), ) - def is_zero(self, value: int) -> bool: + def sum(self, values: Sequence[int]) -> int: + return sum(values) + + def is_zero(self, value: Union[int, RLC]) -> bool: + if isinstance(value, RLC): + value = value.value return value == 0 - def is_equal(self, lhs: int, rhs: int) -> bool: + def is_equal(self, lhs: Union[int, RLC], rhs: Union[int, RLC]) -> bool: + if isinstance(lhs, RLC): + lhs = lhs.value + if isinstance(rhs, RLC): + rhs = rhs.value return self.is_zero(lhs - rhs) def continuous_selectors(self, t: int, n: int) -> Sequence[int]: @@ -172,22 +221,38 @@ def select(self, condition: bool, when_true: int, when_false: int) -> int: def pair_select(self, value: int, lhs: int, rhs: int) -> Tuple[bool, bool]: return value == lhs, value == rhs - def add_words(self, addends: Sequence[int]) -> Tuple[int, int]: - def rlc_to_lo_hi(rlc: int) -> Tuple[Sequence[int], Sequence[int]]: - bytes = self.rlc_to_bytes(rlc, 32) - return self.bytes_to_int(bytes[:16]), self.bytes_to_int(bytes[16:]) + def constant_divmod(self, numerator: int, denominator: int, n_bytes: int) -> Tuple[int, int]: + quotient, remainder = divmod(numerator, denominator) + self.bytes_range_lookup(quotient, n_bytes) + return quotient, remainder + + def compare(self, lhs: int, rhs: int, n_bytes: int) -> Tuple[bool, bool]: + assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + + return lhs < rhs, lhs == rhs + + def min(self, lhs: int, rhs: int, n_bytes: int) -> int: + lt, _ = self.compare(lhs, rhs, n_bytes) + return self.select(lt, lhs, rhs) + + def max(self, lhs: int, rhs: int, n_bytes: int) -> int: + lt, _ = self.compare(lhs, rhs, n_bytes) + return self.select(lt, rhs, lhs) - addends_lo, addends_hi = list(zip(*map(rlc_to_lo_hi, addends))) - carry_lo, sum_lo = divmod(sum(addends_lo), 1 << 128) - carry_hi, sum_hi = divmod(sum(addends_hi) + carry_lo, 1 << 128) + def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, int]: + addends_lo, addends_hi = list(zip(*map(self.word_to_lo_hi, addends))) + + carry_lo, sum_lo = divmod(self.sum(addends_lo), 1 << 128) + carry_hi, sum_hi = divmod(self.sum(addends_hi) + carry_lo, 1 << 128) sum_bytes = sum_lo.to_bytes(16, "little") + sum_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(sum_bytes), carry_hi + return RLC(sum_bytes, self.randomness), carry_hi + + def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: + minuend_lo, minuend_hi = self.word_to_lo_hi(minuend) + subtrahend_lo, subtrahend_hi = self.word_to_lo_hi(subtrahend) - def sub_word(self, minuend: int, subtrahend: int) -> Tuple[int, bool]: - minuend_lo, minuend_hi = self.rlc_to_lo_hi(minuend) - subtrahend_lo, subtrahend_hi = self.rlc_to_lo_hi(subtrahend) borrow_lo = minuend_lo < subtrahend_lo diff_lo = minuend_lo - subtrahend_lo + (1 << 128 if borrow_lo else 0) borrow_hi = minuend_hi < subtrahend_hi + borrow_lo @@ -195,56 +260,81 @@ def sub_word(self, minuend: int, subtrahend: int) -> Tuple[int, bool]: diff_bytes = diff_lo.to_bytes(16, "little") + diff_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(diff_bytes), borrow_hi - - def mul_word_by_u64(self, multiplicand: int, multiplier: int) -> Tuple[int, int]: - multiplicand_bytes = self.rlc_to_bytes(multiplicand, 32) + return RLC(diff_bytes, self.randomness), borrow_hi - multiplicand_lo = self.bytes_to_int(multiplicand_bytes[:16]) - multiplicand_hi = self.bytes_to_int(multiplicand_bytes[16:]) + def mul_word_by_u64(self, multiplicand: RLC, multiplier: int) -> Tuple[RLC, int]: + multiplicand_lo, multiplicand_hi = self.word_to_lo_hi(multiplicand) quotient_lo, product_lo = divmod(multiplicand_lo * multiplier, 1 << 128) quotient_hi, product_hi = divmod(multiplicand_hi * multiplier + quotient_lo, 1 << 128) product_bytes = product_lo.to_bytes(16, "little") + product_hi.to_bytes(16, "little") - return self.rlc_store.to_rlc(product_bytes), quotient_hi + return RLC(product_bytes, self.randomness), quotient_hi + + def rlc_to_le_bytes(self, rlc: RLC) -> Sequence[int]: + return rlc.le_bytes + + def rlc_to_int_unchecked(self, rlc: RLC, n_bytes: int) -> int: + rlc_le_bytes = self.rlc_to_le_bytes(rlc) + return self.bytes_to_int(rlc_le_bytes[:n_bytes]), self.is_zero(self.sum(rlc_le_bytes[n_bytes:])) - def rlc_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: - bytes = self.rlc_store.to_bytes(value) - if len(bytes) > n_bytes and any(bytes[n_bytes:]): - raise ConstraintUnsatFailure(f"{value} is too many bytes to fit {n_bytes} bytes") - return list(bytes) + (n_bytes - len(bytes)) * [0] + def rlc_to_int_exact(self, rlc: RLC, n_bytes: int) -> int: + rlc_le_bytes = self.rlc_to_le_bytes(rlc) - def bytes_to_rlc(self, bytes: Sequence[int]) -> int: - return self.rlc_store.to_rlc(bytes) + if sum(rlc_le_bytes[n_bytes:]) > 0: + raise ConstraintUnsatFailure(f"Value {rlc} has too many bytes to fit {n_bytes} bytes") - def bytes_to_int(self, bytes: Sequence[int]) -> int: - assert len(bytes) <= MAX_N_BYTES, "too many bytes to composite an integer in field" - return linear_combine(bytes, 256) + return self.bytes_to_int(rlc_le_bytes[:n_bytes]) + + def word_to_lo_hi(self, word: RLC) -> Tuple[Sequence[int], Sequence[int]]: + word_le_bytes = self.rlc_to_le_bytes(word) + assert len(word_le_bytes) == 32, "Expected RLC to contain 32 bytes" + return self.bytes_to_int(word_le_bytes[:16]), self.bytes_to_int(word_le_bytes[16:]) + + def int_to_rlc(self, value: int, n_bytes: int) -> RLC: + return RLC(value, self.randomness, n_bytes) + + def bytes_to_int(self, value: Sequence[int]) -> int: + assert len(value) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + + return int.from_bytes(value, "little") + + def range_lookup(self, value: int, range: int): + self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), value, 0, 0]) + + def byte_range_lookup(self, value: int): + self.range_lookup(value, 256) + + def bytes_range_lookup(self, value: int, n_bytes: int) -> Sequence[int]: + assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" - def int_to_bytes(self, value: int, n_bytes: int) -> Sequence[int]: - assert n_bytes <= MAX_N_BYTES, "too many bytes to composite an integer in field" try: return value.to_bytes(n_bytes, "little") except OverflowError: - raise ConstraintUnsatFailure(f"{value} is too many bytes to fit {n_bytes} bytes") - - def byte_range_lookup(self, input: int): - self.tables.fixed_lookup([FixedTableTag.Range256, input, 0, 0]) + raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[int]) -> Array4: return self.tables.fixed_lookup([tag] + inputs) - def block_lookup(self, tag: BlockContextFieldTag, index: int = 0) -> int: + def block_context_lookup(self, tag: BlockContextFieldTag, index: int = 0) -> int: return self.tables.block_lookup([tag, index])[2] - def tx_lookup(self, tx_id: int, tag: TxContextFieldTag, index: int = 0) -> int: - return self.tables.tx_lookup([tx_id, tag, index])[3] + def tx_context_lookup(self, tx_id: int, field_tag: TxContextFieldTag) -> Union[int, RLC]: + return self.tables.tx_lookup([tx_id, field_tag, 0])[3] + + def tx_calldata_lookup(self, tx_id: int, index: int) -> int: + return self.tables.tx_lookup([tx_id, TxContextFieldTag.CallData, index])[3] def bytecode_lookup(self, bytecode_hash: int, index: int, is_code: int) -> int: return self.tables.bytecode_lookup([bytecode_hash, index, Tables._, is_code])[2] + def tx_gas_price(self, tx_id: int) -> int: + return self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice) + + def responsible_opcode_lookup(self, opcode: int): + self.fixed_lookup(FixedTableTag.ResponsibleOpcode, [self.curr.execution_state, opcode]) + def opcode_lookup(self, is_code: bool) -> int: index = self.curr.program_counter + self.program_counter_offset self.program_counter_offset += 1 @@ -257,7 +347,7 @@ def opcode_lookup_at(self, index: int, is_code: bool) -> int: "The opcode source when is_root and is_create (root creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.opcode_source, index, is_code) + return self.bytecode_lookup(self.curr.code_source, index, is_code) def rw_lookup(self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None) -> Array8: if rw_counter is None: @@ -285,48 +375,71 @@ def state_write_with_reversion( inputs: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> Array8: assert tag.write_with_reversion() row = self.rw_lookup(RW.Write, tag, inputs) - rw_counter = rw_counter_end_of_reversion - self.curr.state_write_counter - self.state_write_counter_offset - self.state_write_counter_offset += 1 + if state_write_counter is None: + state_write_counter = self.curr.state_write_counter + self.state_write_counter_offset + self.state_write_counter_offset += 1 + + rw_counter = rw_counter_end_of_reversion - state_write_counter if not is_persistent: # Swap value and value_prev inputs = list(row[3:]) if tag == RWTableTag.TxAccessListAccount: inputs[2], inputs[3] = inputs[3], inputs[2] - elif tag == RWTableTag.TxAccessListStorageSlot: + elif tag == RWTableTag.TxAccessListAccountStorage: inputs[3], inputs[4] = inputs[4], inputs[3] elif tag == RWTableTag.Account: inputs[2], inputs[3] = inputs[3], inputs[2] elif tag == RWTableTag.AccountStorage: - inputs[3], inputs[4] = inputs[4], inputs[3] + inputs[2], inputs[3] = inputs[3], inputs[2] self.rw_lookup(RW.Write, tag, inputs, rw_counter=rw_counter) return row - def call_context_lookup(self, tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Union[int, None] = None) -> int: + def call_context_lookup( + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[int] = None + ) -> int: if call_id is None: call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, tag])[5] + return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag])[5] - def stack_pop(self) -> int: + def stack_pop(self) -> Union[int, RLC]: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 return self.stack_lookup(False, stack_pointer_offset) - def stack_push(self) -> int: + def stack_push(self) -> Union[int, RLC]: self.stack_pointer_offset -= 1 return self.stack_lookup(True, self.stack_pointer_offset) - def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> int: + def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[int, RLC]: stack_pointer = self.curr.stack_pointer + stack_pointer_offset return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer])[5] + def memory_write(self, memory_address: int, call_id: Optional[int] = None) -> int: + return self.memory_lookup(RW.Write, memory_address, call_id) + + def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = None) -> int: + if call_id is None: + call_id = self.curr.call_id + + return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address])[5] + + def tx_refund_read(self, tx_id) -> int: + row = self.rw_lookup(RW.Read, RWTableTag.TxRefund, [tx_id]) + return row[4] + + def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> int: + row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) + return row[5] + def account_write( self, account_address: int, @@ -345,20 +458,23 @@ def account_write_with_reversion( account_field_tag: AccountFieldTag, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> Tuple[int, int]: row = self.state_write_with_reversion( RWTableTag.Account, [account_address, account_field_tag], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[5], row[6] - def add_balance(self, account_address: int, values: Sequence[int]): + def add_balance(self, account_address: int, values: Sequence[int]) -> Tuple[int, int]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result) self.constrain_zero(carry) + return balance, balance_prev def add_balance_with_reversion( self, @@ -366,19 +482,22 @@ def add_balance_with_reversion( values: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - ): + state_write_counter: Optional[int] = None, + ) -> Tuple[int, int]: balance, balance_prev = self.account_write_with_reversion( - account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion + account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion, state_write_counter ) result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result) self.constrain_zero(carry) + return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[int]): + def sub_balance(self, account_address: int, values: Sequence[int]) -> Tuple[int, int]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result) self.constrain_zero(carry) + return balance, balance_prev def sub_balance_with_reversion( self, @@ -386,17 +505,15 @@ def sub_balance_with_reversion( values: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - ): + state_write_counter: Optional[int] = None, + ) -> Tuple[int, int]: balance, balance_prev = self.account_write_with_reversion( - account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion + account_address, AccountFieldTag.Balance, is_persistent, rw_counter_end_of_reversion, state_write_counter ) result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result) self.constrain_zero(carry) - - def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> Tuple[int, int]: - row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) - return row[5], row[6] + return balance, balance_prev def add_account_to_access_list( self, @@ -416,40 +533,45 @@ def add_account_to_access_list_with_reversion( account_address: int, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> bool: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccount, [tx_id, account_address, 1], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[5] - row[6] - def add_storage_slot_to_access_list( + def add_account_storage_to_access_list( self, tx_id: int, account_address: int, - storage_slot: int, + storage_key: int, ) -> bool: - row = self.state_write_with_reversion( - RWTableTag.TxAccessListAccount, - [tx_id, account_address, storage_slot, 1], + row = self.rw_lookup( + RW.Write, + RWTableTag.TxAccessListAccountStorage, + [tx_id, account_address, storage_key, 1], ) return row[6] - row[7] - def add_storage_slot_to_access_list_with_reversion( + def add_account_storage_to_access_list_with_reversion( self, tx_id: int, account_address: int, - storage_slot: int, + storage_key: int, is_persistent: bool, rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, ) -> bool: row = self.state_write_with_reversion( - RWTableTag.TxAccessListAccount, - [tx_id, account_address, storage_slot, 1], + RWTableTag.TxAccessListAccountStorage, + [tx_id, account_address, storage_key, 1], is_persistent, rw_counter_end_of_reversion, + state_write_counter, ) return row[6] - row[7] @@ -461,16 +583,88 @@ def transfer_with_gas_fee( gas_fee: int, is_persistent: bool, rw_counter_end_of_reversion: int, - ): - self.sub_balance_with_reversion( + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + sender_balance_pair = self.sub_balance_with_reversion( sender_address, [value, gas_fee], is_persistent, rw_counter_end_of_reversion, ) - self.add_balance_with_reversion( + receiver_balance_pair = self.add_balance_with_reversion( + receiver_address, + [value], + is_persistent, + rw_counter_end_of_reversion, + ) + return sender_balance_pair, receiver_balance_pair + + def transfer( + self, + sender_address: int, + receiver_address: int, + value: int, + is_persistent: bool, + rw_counter_end_of_reversion: int, + state_write_counter: Optional[int] = None, + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + sender_balance_pair = self.sub_balance_with_reversion( + sender_address, + [value], + is_persistent, + rw_counter_end_of_reversion, + state_write_counter, + ) + receiver_balance_pair = self.add_balance_with_reversion( receiver_address, [value], is_persistent, rw_counter_end_of_reversion, + None if state_write_counter is None else state_write_counter + 1, + ) + return sender_balance_pair, receiver_balance_pair + + def memory_offset_and_length_to_int(self, offset: RLC, length: RLC) -> Tuple[int, int]: + length = self.rlc_to_int_exact(length, N_BYTES_MEMORY_ADDRESS) + if self.is_zero(length): + return 0, 0 + + offset = self.rlc_to_int_exact(offset, N_BYTES_MEMORY_ADDRESS) + return offset, length + + def memory_gas_cost(self, memory_size: int) -> int: + quadratic_cost, _ = self.constant_divmod( + memory_size * memory_size, MEMORY_EXPANSION_QUAD_DENOMINATOR, N_BYTES_GAS ) + linear_cost = MEMORY_EXPANSION_LINEAR_COEFF * memory_size + return quadratic_cost + linear_cost + + def memory_expansion_constant_length(self, offset: int, length: int) -> Tuple[int, int]: + memory_size, _ = self.constant_divmod(length + offset + 31, 32, N_BYTES_MEMORY_SIZE) + + next_memory_size = self.max(self.curr.memory_size, memory_size, N_BYTES_MEMORY_SIZE) + + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) + memory_gas_cost_next = self.memory_gas_cost(next_memory_size) + memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost + + return next_memory_size, memory_expansion_gas_cost + + def memory_expansion_dynamic_length( + self, + cd_offset: int, + cd_length: int, + rd_offset: Optional[int] = None, + rd_length: Optional[int] = None, + ) -> Tuple[int, int]: + cd_memory_size, _ = self.constant_divmod(cd_offset + cd_length + 31, 32, N_BYTES_MEMORY_SIZE) + next_memory_size = self.max(self.curr.memory_size, cd_memory_size, N_BYTES_MEMORY_SIZE) + + if rd_offset is not None: + rd_memory_size, _ = self.constant_divmod(rd_offset + rd_length + 31, 32, N_BYTES_MEMORY_SIZE) + next_memory_size = self.max(next_memory_size, rd_memory_size, N_BYTES_MEMORY_SIZE) + + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) + memory_gas_cost_next = self.memory_gas_cost(next_memory_size) + memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost + + return next_memory_size, memory_expansion_gas_cost diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index 0564289dc..f680a243e 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -1,15 +1,6 @@ from typing import Sequence -from ..util.arithmetic import RLCStore -from .execution import ( - add, - begin_tx, - push, - jump, - jumpi, - coinbase, - caller, -) +from .execution import EXECUTION_STATE_IMPL from .execution_state import ExecutionState from .instruction import Instruction from .step import StepState @@ -17,49 +8,37 @@ def verify_steps( - rlc_store: RLCStore, + randomness: int, tables: Tables, steps: Sequence[StepState], begin_with_first_step: bool = False, - end_with_final_step: bool = False, + end_with_last_step: bool = False, ): - for idx in range(len(steps) - 1): + # TODO: Enforce general ExecutionState transition constraint + + for idx in range(len(steps) if end_with_last_step else len(steps) - 1): verify_step( - Instruction(rlc_store=rlc_store, tables=tables, curr=steps[idx], next=steps[idx + 1]), - begin_with_first_step and idx == 0, - end_with_final_step and idx == len(steps) - 2, + Instruction( + randomness=randomness, + tables=tables, + curr=steps[idx], + next=steps[idx + 1] if idx + 1 < len(steps) else None, + is_first_step=begin_with_first_step and idx == 0, + is_last_step=idx + 1 == len(steps), + ), ) def verify_step( instruction: Instruction, - is_first_step: bool = False, - is_final_step: bool = False, ): - if is_first_step: + if instruction.is_first_step: instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.BeginTx) - if instruction.curr.execution_state == ExecutionState.BeginTx: - begin_tx(instruction, is_first_step) - # Opcode's successful cases - elif instruction.curr.execution_state == ExecutionState.ADD: - add(instruction) - elif instruction.curr.execution_state == ExecutionState.PUSH: - push(instruction) - elif instruction.curr.execution_state == ExecutionState.JUMP: - jump(instruction) - elif instruction.curr.execution_state == ExecutionState.JUMPI: - jumpi(instruction) - elif instruction.curr.execution_state == ExecutionState.COINBASE: - coinbase(instruction) - elif instruction.curr.execution_state == ExecutionState.CALLER: - caller(instruction) - # Error cases + if instruction.curr.execution_state in EXECUTION_STATE_IMPL: + EXECUTION_STATE_IMPL[instruction.curr.execution_state](instruction) else: raise NotImplementedError - if is_final_step: - # Verify no malicious insertion - assert instruction.curr.rw_counter == len(instruction.tables.rw_table) - - # TODO: Verify final step has the tx_id identical to the amount in tx_table + if instruction.is_last_step: + instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.EndBlock) diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 362126dac..21c24aedd 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -149,6 +149,24 @@ class Opcode(IntEnum): def hex(self) -> str: return "{:02x}".format(self) + def bytes(self) -> bytes: + return bytes([self]) + + def is_push(self) -> bool: + return Opcode.PUSH1 <= self <= Opcode.PUSH32 + + def is_dup(self) -> bool: + return Opcode.DUP1 <= self <= Opcode.DUP16 + + def is_swap(self) -> bool: + return Opcode.SWAP1 <= self <= Opcode.SWAP16 + + def max_stack_pointer(self) -> int: + return OPCODE_INFO_MAP[self].max_stack_pointer + + def min_stack_pointer(self) -> int: + return OPCODE_INFO_MAP[self].min_stack_pointer + def constant_gas_cost(self) -> int: return OPCODE_INFO_MAP[self].constant_gas_cost @@ -339,15 +357,14 @@ def valid_opcodes() -> Sequence[Opcode]: def invalid_opcodes() -> Sequence[int]: - return [opcode for opcode in range(256) if opcode not in OPCODE_INFO_MAP] + return [opcode for opcode in range(256) if opcode not in valid_opcodes()] def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: pairs = [] for opcode in valid_opcodes(): - opcode_info = OPCODE_INFO_MAP[opcode] - if opcode_info.min_stack_pointer > 0: - for stack_pointer in range(opcode_info.min_stack_pointer): + if opcode.min_stack_pointer() > 0: + for stack_pointer in range(opcode.min_stack_pointer()): pairs.append((opcode, stack_pointer)) return pairs @@ -355,9 +372,8 @@ def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: def stack_underflow_pairs() -> Sequence[Tuple[int, int]]: pairs = [] for opcode in valid_opcodes(): - opcode_info = OPCODE_INFO_MAP[opcode] - if opcode_info.max_stack_pointer < 1024: - for stack_pointer in range(opcode_info.max_stack_pointer, 1024): + if opcode.max_stack_pointer() < 1024: + for stack_pointer in range(opcode.max_stack_pointer(), 1024): pairs.append((opcode, stack_pointer + 1)) return pairs diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index dd89ef1ab..ea5ccf8cf 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -6,7 +6,7 @@ class StepState: Step state EVM circuit tracks step by step and used to ensure the execution trace is verified continuously and chronologically. It includes fields that are used from beginning to end like is_root, - is_create and opcode_source. + is_create and code_source. It also includes call's mutable states which change almost every step like program_counter and stack_pointer. """ @@ -18,15 +18,15 @@ class StepState: # The following 3 fields decide the opcode source. There are 2 possible # cases: # 1. Root creation call (is_root and is_create) - # It was planned to set the opcode_source to tx_id, then lookup tx_table's + # It was planned to set the code_source to tx_id, then lookup tx_table's # CallData field directly, but is still yet to be determined. # See the issue https://github.com/appliedzkp/zkevm-specs/issues/73 for # further discussion. # 2. Deployed contract interaction or internal creation call - # We set opcode_source to bytecode_hash and lookup bytecode_table. + # We set code_source to bytecode_hash and lookup bytecode_table. is_root: bool is_create: bool - opcode_source: int + code_source: int # The following fields change almost every step. program_counter: int @@ -45,10 +45,10 @@ def __init__( self, execution_state: ExecutionState, rw_counter: int, - call_id: int, + call_id: int = 0, is_root: bool = False, is_create: bool = False, - opcode_source: int = 0, + code_source: int = 0, program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, @@ -63,7 +63,7 @@ def __init__( self.call_id = call_id self.is_root = is_root self.is_create = is_create - self.opcode_source = opcode_source + self.code_source = code_source self.program_counter = program_counter self.stack_pointer = stack_pointer self.gas_left = gas_left diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index a970b09d2..116605a2f 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,5 +1,7 @@ +from __future__ import annotations from typing import Sequence, Set, Tuple from enum import IntEnum, auto +from itertools import chain, product from ..util import Array3, Array4, Array8 from .execution_state import ExecutionState @@ -38,6 +40,62 @@ class FixedTableTag(IntEnum): StackOverflow = auto() # opcode, stack_pointer, 0 StackUnderflow = auto() # opcode, stack_pointer, 0 + def table_assignments(self) -> Sequence[Array4]: + if self == FixedTableTag.Range16: + return [(self, i, 0, 0) for i in range(16)] + elif self == FixedTableTag.Range32: + return [(self, i, 0, 0) for i in range(32)] + elif self == FixedTableTag.Range64: + return [(self, i, 0, 0) for i in range(64)] + elif self == FixedTableTag.Range256: + return [(self, i, 0, 0) for i in range(256)] + elif self == FixedTableTag.Range512: + return [(self, i, 0, 0) for i in range(512)] + elif self == FixedTableTag.Range1024: + return [(self, i, 0, 0) for i in range(1024)] + elif self == FixedTableTag.SignByte: + return [(self, i, (i >> 7) * 0xFF, 0) for i in range(256)] + elif self == FixedTableTag.BitwiseAnd: + return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.BitwiseOr: + return [(self, lhs, rhs, lhs | rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.BitwiseXor: + return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] + elif self == FixedTableTag.ResponsibleOpcode: + return [ + (self, execution_state, opcode, 0) + for execution_state in list(ExecutionState) + for opcode in execution_state.responsible_opcode() + ] + elif self == FixedTableTag.InvalidOpcode: + return [(self, opcode, 0, 0) for opcode in invalid_opcodes()] + elif self == FixedTableTag.StateWriteOpcode: + return [(self, opcode, 0, 0) for opcode in state_write_opcodes()] + elif self == FixedTableTag.StackOverflow: + return [(self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_underflow_pairs()] + elif self == FixedTableTag.StackUnderflow: + return [(self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs()] + else: + ValueError("Unreacheable") + + def range_table_tag(range: int) -> FixedTableTag: + if range == 16: + return FixedTableTag.Range16 + elif range == 32: + return FixedTableTag.Range32 + elif range == 64: + return FixedTableTag.Range64 + elif range == 256: + return FixedTableTag.Range256 + elif range == 512: + return FixedTableTag.Range512 + elif range == 1024: + return FixedTableTag.Range1024 + else: + raise ValueError( + f"Range {range} lookup is not supported yet, please add a new variant Range{range} in FixedTableTag with proper table assignments" + ) + class BlockContextFieldTag(IntEnum): """ @@ -49,11 +107,11 @@ class BlockContextFieldTag(IntEnum): Coinbase = auto() GasLimit = auto() - BlockNumber = auto() - Time = auto() + Number = auto() + Timestamp = auto() Difficulty = auto() BaseFee = auto() - BlockHash = auto() + HistoryHash = auto() class TxContextFieldTag(IntEnum): @@ -89,7 +147,7 @@ class RWTableTag(IntEnum): """ TxAccessListAccount = auto() - TxAccessListStorageSlot = auto() + TxAccessListAccountStorage = auto() TxRefund = auto() Account = auto() @@ -105,7 +163,7 @@ class RWTableTag(IntEnum): def write_with_reversion(self) -> bool: return self in [ RWTableTag.TxAccessListAccount, - RWTableTag.TxAccessListStorageSlot, + RWTableTag.TxAccessListAccountStorage, RWTableTag.Account, RWTableTag.AccountStorage, ] @@ -138,8 +196,8 @@ class CallContextFieldTag(IntEnum): # It's not like transaction or bytecode that require specifically friendly # layout for verification, so maintaining the consistency directly in # RWTable seems more intuitive than creating another table for it. - RWCounterEndOfReversion = auto() # to know at which point in the future we should revert - CallerCallId = auto() # to know caller's id + RwCounterEndOfReversion = auto() # to know at which point in the future we should revert + CallerId = auto() # to know caller's id TxId = auto() # to know tx's id Depth = auto() # to know if call too deep CallerAddress = auto() @@ -149,10 +207,17 @@ class CallContextFieldTag(IntEnum): ReturnDataOffset = auto() # for callee to set return_data to caller's memeory ReturnDataLength = auto() Value = auto() - Result = auto() # to peek result in the future + IsSuccess = auto() # to peek result in the future IsPersistent = auto() # to know if current call is within reverted call or not IsStatic = auto() # to know if state modification is within static call or not + # The following are read-only data inside a call like previous section for + # opcode RETURNDATASIZE and RETURNDATACOPY, except they will be updated when + # end of callee execution. + LastCalleeId = auto() + LastCalleeReturnDataOffset = auto() + LastCalleeReturnDataLength = auto() + # The following are used by caller to save its own CallState when it's # going to dive into another call, and will be read out to restore caller's # CallState in the end by callee. @@ -161,7 +226,7 @@ class CallContextFieldTag(IntEnum): # different kinds of RWTableTag. IsRoot = auto() IsCreate = auto() - OpcodeSource = auto() + CodeSource = auto() ProgramCounter = auto() StackPointer = auto() GasLeft = auto() @@ -193,44 +258,18 @@ class Tables: # - value1 # - value2 # - value3 - fixed_table: Set[Array4] = set( - [(FixedTableTag.Range16, i, 0, 0) for i in range(16)] - + [(FixedTableTag.Range32, i, 0, 0) for i in range(32)] - + [(FixedTableTag.Range64, i, 0, 0) for i in range(64)] - + [(FixedTableTag.Range256, i, 0, 0) for i in range(256)] - + [(FixedTableTag.Range512, i, 0, 0) for i in range(512)] - + [(FixedTableTag.Range1024, i, 0, 0) for i in range(1024)] - + [(FixedTableTag.SignByte, i, (i & 1) * 0xFF, 0) for i in range(256)] - + [(FixedTableTag.BitwiseAnd, lhs, rhs, lhs & rhs) for lhs in range(256) for rhs in range(256)] - + [(FixedTableTag.BitwiseOr, lhs, rhs, lhs | rhs) for lhs in range(256) for rhs in range(256)] - + [(FixedTableTag.BitwiseXor, lhs, rhs, lhs ^ rhs) for lhs in range(256) for rhs in range(256)] - + [ - (FixedTableTag.ResponsibleOpcode, execution_state, opcode, 0) - for execution_state in list(ExecutionState) - for opcode in execution_state.responsible_opcode() - ] - + [(FixedTableTag.InvalidOpcode, opcode, 0, 0) for opcode in invalid_opcodes()] - + [(FixedTableTag.StateWriteOpcode, opcode, 0, 0) for opcode in state_write_opcodes()] - + [ - (FixedTableTag.StackUnderflow, opcode, stack_pointer, 0) - for (opcode, stack_pointer) in stack_underflow_pairs() - ] - + [ - (FixedTableTag.StackOverflow, opcode, stack_pointer, 0) - for (opcode, stack_pointer) in stack_overflow_pairs() - ] - ) + fixed_table: Set[Array4] = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) # Each row in BlockTable contains: # - tag - # - block_number_or_zero (meaningful only for BlockHash, will be zero for other tags) + # - block_number_or_zero (meaningful only for HistoryHash, will be zero for other tags) # - value block_table: Set[Array3] # Each row in TxTable contains: # - tx_id # - tag - # - index_or_zero (meaningful only for CallData, will be zero for other tags) + # - call_data_index_or_zero (meaningful only for CallData, will be zero for other tags) # - value tx_table: Set[Array4] diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index f3bce88fa..a07ed8241 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,4 +1,5 @@ -from typing import Iterator, Optional, Sequence, Union +from __future__ import annotations +from typing import Any, Dict, Iterator, NewType, Optional, Sequence from functools import reduce from itertools import chain @@ -8,58 +9,64 @@ U256, Array3, Array4, - RLCStore, + RLC, keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, ) from .table import BlockContextFieldTag, TxContextFieldTag -from .opcode import get_push_size +from .opcode import get_push_size, Opcode class Block: coinbase: U160 + + # Gas needs a lot arithmetic operation or comparison in EVM circuit, so we + # assume gas limit in the near futuer will not exceed U64, to reduce the + # implementation complexity. gas_limit: U64 - block_number: U256 - time: U64 + + # For other fields, we follow the size defined in yellow paper for now. + number: U256 + timestamp: U256 difficulty: U256 base_fee: U256 - # history_hashes contains most recent 256 block hashes in history, where - # the lastest one is at history_hashes[-1]. + # It contains most recent 256 block hashes in history, where the lastest + # one is at history_hashes[-1]. history_hashes: Sequence[U256] def __init__( self, coinbase: U160 = 0x10, gas_limit: U64 = int(15e6), - block_number: U256 = 0, - time: U64 = 0, + number: U256 = 0, + timestamp: U256 = 0, difficulty: U256 = 0, base_fee: U256 = int(1e9), history_hashes: Sequence[U256] = [], ) -> None: - assert len(history_hashes) <= min(256, block_number) + assert len(history_hashes) <= min(256, number) self.coinbase = coinbase self.gas_limit = gas_limit - self.block_number = block_number - self.time = time + self.number = number + self.timestamp = timestamp self.difficulty = difficulty self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, rlc_store: RLCStore) -> Sequence[Array3]: + def table_assignments(self, randomness: int) -> Sequence[Array3]: return [ (BlockContextFieldTag.Coinbase, 0, self.coinbase), (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.BlockNumber, 0, rlc_store.to_rlc(self.block_number, 32)), - (BlockContextFieldTag.Time, 0, rlc_store.to_rlc(self.time, 32)), - (BlockContextFieldTag.Difficulty, 0, rlc_store.to_rlc(self.difficulty, 32)), - (BlockContextFieldTag.BaseFee, 0, rlc_store.to_rlc(self.base_fee, 32)), + (BlockContextFieldTag.Number, 0, RLC(self.number, randomness)), + (BlockContextFieldTag.Timestamp, 0, RLC(self.timestamp, randomness)), + (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), + (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), ] + [ - (BlockContextFieldTag.BlockHash, self.block_number - idx - 1, rlc_store.to_rlc(block_hash, 32)) - for idx, block_hash in enumerate(reversed(self.history_hashes)) + (BlockContextFieldTag.HistoryHash, self.number - idx - 1, RLC(history_hash, randomness)) + for idx, history_hash in enumerate(reversed(self.history_hashes)) ] @@ -93,63 +100,106 @@ def __init__( self.value = value self.call_data = call_data - def table_assignments(self, rlc_store: RLCStore) -> Iterator[Array4]: - def call_data_gas_cost_per_byte(byte: int): - return GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE if byte is 0 else GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE + def call_data_gas_cost(self) -> int: + return reduce( + lambda acc, byte: ( + acc + (GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE if byte is 0 else GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE) + ), + self.call_data, + 0, + ) - call_data_gas_cost = reduce(lambda acc, byte: acc + call_data_gas_cost_per_byte(byte), self.call_data, 0) + def table_assignments(self, randomness: int) -> Iterator[Array4]: return chain( [ (self.id, TxContextFieldTag.Nonce, 0, self.nonce), (self.id, TxContextFieldTag.Gas, 0, self.gas), - (self.id, TxContextFieldTag.GasPrice, 0, rlc_store.to_rlc(self.gas_price, 32)), + (self.id, TxContextFieldTag.GasPrice, 0, RLC(self.gas_price, randomness)), (self.id, TxContextFieldTag.CallerAddress, 0, self.caller_address), (self.id, TxContextFieldTag.CalleeAddress, 0, self.callee_address), (self.id, TxContextFieldTag.IsCreate, 0, self.callee_address is None), - (self.id, TxContextFieldTag.Value, 0, rlc_store.to_rlc(self.value, 32)), + (self.id, TxContextFieldTag.Value, 0, RLC(self.value, randomness)), (self.id, TxContextFieldTag.CallDataLength, 0, len(self.call_data)), - (self.id, TxContextFieldTag.CallDataGasCost, 0, call_data_gas_cost), + (self.id, TxContextFieldTag.CallDataGasCost, 0, self.call_data_gas_cost()), ], map(lambda item: (self.id, TxContextFieldTag.CallData, item[0], item[1]), enumerate(self.call_data)), ) class Bytecode: - hash: U256 - bytes: bytes - - def __init__( - self, - str_or_bytes: Union[str, bytes], - ): - if type(str_or_bytes) is str: - str_or_bytes = bytes.fromhex(str_or_bytes) - - self.hash = int.from_bytes(keccak256(str_or_bytes), "little") - self.bytes = str_or_bytes - - def table_assignments(self, rlc_store: RLCStore) -> Iterator[Array4]: + code: bytearray + + def __init__(self, code: Optional[bytearray] = None) -> None: + self.code = bytearray() if code is None else code + + def __getattr__(self, name: str): + def method(*args) -> Bytecode: + try: + opcode = Opcode[name.removesuffix("_").upper()] + except KeyError: + raise ValueError(f"Invalid opcode {name}") + + if opcode.is_push(): + assert len(args) == 1 + self.push(args[0], opcode - Opcode.PUSH1 + 1) + elif opcode.is_dup() or opcode.is_swap(): + assert len(args) == 0 + self.code.append(opcode) + else: + assert len(args) <= 1024 - opcode.max_stack_pointer() + for arg in reversed(args): + self.push(arg, 32) + self.code.append(opcode) + + return self + + return method + + def push(self, value: Any, n_bytes: int = 32) -> Bytecode: + if isinstance(value, int): + value = value.to_bytes(n_bytes, "big") + elif isinstance(value, str): + value = bytes.fromhex(value.lower().removeprefix("0x")) + elif isinstance(value, RLC): + value = value.be_bytes() + elif isinstance(value, bytes) or isinstance(value, bytearray): + ... + else: + raise NotImplementedError(f"Value of type {type(value)} is not yet supported") + + assert 0 < len(value) <= n_bytes, ValueError("Too many bytes as data portion of PUSH*") + + opcode = Opcode.PUSH1 + n_bytes - 1 + self.code.append(opcode) + self.code.extend(value.rjust(n_bytes, bytes(1))) + + return self + + def hash(self) -> int: + return int.from_bytes(keccak256(self.code), "big") + + def table_assignments(self, randomness: int) -> Iterator[Array4]: class BytecodeIterator: idx: int push_data_left: int - hash: int - bytes: bytes + hash: RLC + code: bytes - def __init__(self, hash: int, bytes: bytes): + def __init__(self, hash: RLC, code: bytes): self.idx = 0 self.push_data_left = 0 self.hash = hash - self.bytes = bytes + self.code = code def __iter__(self): return self def __next__(self): - if self.idx == len(self.bytes): + if self.idx == len(self.code): raise StopIteration idx = self.idx - byte = self.bytes[idx] + byte = self.code[idx] is_code = self.push_data_left == 0 self.push_data_left = get_push_size(byte) if is_code else self.push_data_left - 1 @@ -158,4 +208,35 @@ def __next__(self): return (self.hash, idx, byte, is_code) - return BytecodeIterator(rlc_store.to_rlc(self.hash, 32), self.bytes) + return BytecodeIterator(RLC(self.hash(), randomness), self.code) + + +Storage = NewType("Storage", Dict[U256, U256]) + + +class Account: + address: U160 + nonce: U256 + balance: U256 + code: Bytecode + storage: Storage + + def __init__( + self, + address: U160 = 0, + nonce: U256 = 0, + balance: U256 = 0, + code: Optional[Bytecode] = None, + storage: Optional[Storage] = None, + ) -> None: + self.address = address + self.nonce = nonce + self.balance = balance + self.code = Bytecode() if code is None else code + self.storage = dict() if storage is None else storage + + def code_hash(self) -> U256: + return self.code.hash() + + def storage_trie_hash(self) -> U256: + raise NotImplementedError("Trie has not been implemented") diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 2bfab45de..3c8552c4b 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -8,14 +8,14 @@ from .typing import * -def hex_to_word(hex: str) -> bytes: - return bytes.fromhex(hex.removeprefix("0x").zfill(64)) - - def rand_range(stop: Union[int, float] = 2 ** 256) -> int: return randrange(0, int(stop)) +def rand_fp() -> int: + return rand_range(FP_MODULUS) + + def rand_address() -> U160: return rand_range(2 ** 160) diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index 8f70b6265..068ef6d32 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -1,9 +1,5 @@ -from typing import Dict, Sequence, Tuple, Union -from Crypto.Random import get_random_bytes -from Crypto.Random.random import randrange - -from .param import MAX_N_BYTES - +from __future__ import annotations +from typing import Sequence, Union # BN254 scalar field size FP_MODULUS = 21888242871839275222246405745257275088548364400416034343698204186575808495617 @@ -21,70 +17,43 @@ def fp_inv(value: int) -> int: return pow(value, -1, FP_MODULUS) -def le_to_int(bytes: Sequence[int]) -> int: - assert len(bytes) <= MAX_N_BYTES, "too many bytes to composite an integer in field" - return linear_combine(bytes, 256) +def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: int) -> int: + com = 0 + for byte in reversed(le_bytes): + assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" + com = fp_add(fp_mul(com, factor), byte) + return com + +class RLC: + le_bytes: bytes + value: int -def linear_combine(bytes: Sequence[int], r: int) -> int: - ret = 0 - for byte in reversed(bytes): - assert 0 <= byte < 256, "bytes for linear combination should be already checked in range" - ret = fp_add(fp_mul(ret, r), byte) - return ret + def __init__(self, int_or_bytes: Union[int, bytes], randomness: int, n_bytes: int = 32) -> None: + if isinstance(int_or_bytes, int): + assert 0 <= int_or_bytes < 256 ** n_bytes, f"Value {int_or_bytes} too large to fit {n_bytes} bytes" + self.le_bytes = int_or_bytes.to_bytes(n_bytes, "little") + elif isinstance(int_or_bytes, bytes): + assert len(int_or_bytes) <= n_bytes, f"Expected bytes with length less or equal than {n_bytes}" + self.le_bytes = int_or_bytes.ljust(n_bytes, b"\x00") + else: + raise TypeError(f"Expected an int or bytes, but got object of type {type(int_or_bytes)}") + self.value = fp_linear_combine(self.le_bytes, randomness) -class RLCStore: - randomness: int - rlc_to_bytes: Dict[int, bytes] = dict() + def __eq__(self, rhs: Union[int, RLC]): + if isinstance(rhs, int): + return self.value == rhs + if isinstance(rhs, RLC): + return self.value == rhs.value + else: + raise TypeError(f"Expected a RLC, but got object of type {type(rhs)}") - def __init__(self, randomness: int = randrange(0, FP_MODULUS)) -> None: - self.randomness = randomness - for byte in range(256): - self.to_rlc([byte]) + def __hash__(self) -> int: + return self.value - def to_rlc(self, seq_or_int: Union[Sequence[int], int], n_bytes: int = 0) -> int: - seq = seq_or_int - if type(seq_or_int) == int: - seq = seq_or_int.to_bytes(n_bytes, "little") - rlc = linear_combine(seq, self.randomness) + def __repr__(self) -> str: + return int.from_bytes(self.le_bytes, "little").__repr__() - if rlc in self.rlc_to_bytes: - maxlen = max(len(self.rlc_to_bytes[rlc]), len(seq)) - assert self.rlc_to_bytes[rlc].rjust(maxlen, b"\x00") == bytes(seq).rjust( - maxlen, b"\x00" - ), f"Random lienar combination collision on {self.rlc_to_bytes[rlc]} and {bytes(seq)} with randomness {self.randomness}" - else: - self.rlc_to_bytes[rlc] = bytes(seq) - - return rlc - - def to_bytes(self, rlc: int) -> bytes: - return self.rlc_to_bytes[rlc] - - def rand(self, n_bytes: int = 32) -> Tuple[int, bytes]: - bytes = get_random_bytes(n_bytes) - return self.to_rlc(bytes), bytes - - def add(self, lhs: int, rhs: int, modulus: int = 2 ** 256) -> Tuple[int, bytes, bool]: - lhs_bytes = self.to_bytes(lhs) - rhs_bytes = self.to_bytes(rhs) - carry, result = divmod( - int.from_bytes(lhs_bytes, "little") + int.from_bytes(rhs_bytes, "little"), - modulus, - ) - result_bytes = result.to_bytes(32, "little") - return self.to_rlc(result_bytes), result_bytes, carry > 0 - - def sub(self, lhs: int, rhs: int, modulus: int = 2 ** 256) -> Tuple[int, bytes, bool]: - lhs_bytes = self.to_bytes(lhs) - rhs_bytes = self.to_bytes(rhs) - borrow, result = divmod( - int.from_bytes(lhs_bytes, "little") - int.from_bytes(rhs_bytes, "little"), - modulus, - ) - assert ( - result + int.from_bytes(rhs_bytes, "little") == int.from_bytes(lhs_bytes, "little") + (borrow < 0) * modulus - ) - result_bytes = result.to_bytes(32, "little") - return self.to_rlc(result_bytes), result_bytes, borrow < 0 + def be_bytes(self) -> bytes: + return bytes(reversed(self.le_bytes)) diff --git a/src/zkevm_specs/util/hash.py b/src/zkevm_specs/util/hash.py index f196e3d6f..99d750cff 100644 --- a/src/zkevm_specs/util/hash.py +++ b/src/zkevm_specs/util/hash.py @@ -1,13 +1,15 @@ from typing import Union from Crypto.Hash import keccak +from .typing import U256 -def keccak256(data: Union[str, bytes]) -> bytes: + +def keccak256(data: Union[str, bytes, bytearray]) -> bytes: if type(data) == str: data = bytes.fromhex(data) return keccak.new(digest_bits=256).update(data).digest() -EMPTY_HASH = keccak256("") -EMPTY_CODE_HASH = EMPTY_HASH -EMPTY_TRIE_HASH = keccak256("80") +EMPTY_HASH: U256 = int.from_bytes(keccak256(""), "big") +EMPTY_CODE_HASH: U256 = EMPTY_HASH +EMPTY_TRIE_HASH: U256 = int.from_bytes(keccak256("80"), "big") diff --git a/src/zkevm_specs/util/param.py b/src/zkevm_specs/util/param.py index 71c0192d8..41e2a6a39 100644 --- a/src/zkevm_specs/util/param.py +++ b/src/zkevm_specs/util/param.py @@ -1,9 +1,40 @@ # Maximun number of bytes with composition value that doesn't wrap around the field MAX_N_BYTES = 31 +# Number of bytes of account address +N_BYTES_ACCOUNT_ADDRESS = 20 +# Number of bytes of memory address +N_BYTES_MEMORY_ADDRESS = 5 +# Number of bytes of memory size (in word) +N_BYTES_MEMORY_SIZE = 4 # Number of bytes of gas N_BYTES_GAS = 8 +# Number of bytes of program counter +N_BYTES_PROGRAM_COUNTER = 8 +# Gas cost of non-creation transaction +GAS_COST_TX = 21000 +# Gas cost of creation transaction +GAS_COST_CREATION_TX = 53000 # Gas cost of transaction call_data per non-zero byte GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE = 16 # Gas cost of transaction call_data per zero byte GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE = 4 +# Gas cost of accessing account or storage slot +GAS_COST_WARM_ACCESS = 100 +# Extra gas cost of not-yet-accessed account +EXTRA_GAS_COST_ACCOUNT_COLD_ACCESS = 2500 +# Extra gas cost of not-yet-accessed storage slot +EXTRA_GAS_COST_STORAGE_SLOT_COLD_ACCESS = 2000 +# Gas cost of calling with non-zero value +GAS_COST_CALL_WITH_VALUE = 9000 +# Gas cost of calling empty account +GAS_COST_CALL_EMPTY_ACCOUNT = 25000 +# Gas stipend given if call with non-zero value +GAS_STIPEND_CALL_WITH_VALUE = 2300 + +# Quotient for max refund of gas used +MAX_REFUND_QUOTIENT_OF_GAS_USED = 5 +# Denominator of quadratic part of memory expansion gas cost +MEMORY_EXPANSION_QUAD_DENOMINATOR = 512 +# Coefficient of linear part of memory expansion gas cost +MEMORY_EXPANSION_LINEAR_COEFF = 3 diff --git a/tests/evm/test_add.py b/tests/evm/test_add.py index 8ba0371b8..d54a97ddf 100644 --- a/tests/evm/test_add.py +++ b/tests/evm/test_add.py @@ -12,37 +12,32 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util import rand_fp, rand_word, RLC TESTING_DATA = ( - (Opcode.ADD, hex_to_word("030201"), hex_to_word("060504"), hex_to_word("090705")), - (Opcode.SUB, hex_to_word("090705"), hex_to_word("060504"), hex_to_word("030201")), - (Opcode.ADD, rand_bytes(), rand_bytes(), None), - (Opcode.SUB, rand_bytes(), rand_bytes(), None), + (Opcode.ADD, 0x030201, 0x060504, 0x090705), + (Opcode.SUB, 0x090705, 0x060504, 0x030201), + (Opcode.ADD, rand_word(), rand_word(), None), + (Opcode.SUB, rand_word(), rand_word(), None), ) -@pytest.mark.parametrize("opcode, a_bytes, b_bytes, c_bytes", TESTING_DATA) -def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[bytes]): - rlc_store = RLCStore() +@pytest.mark.parametrize("opcode, a, b, c", TESTING_DATA) +def test_add(opcode: Opcode, a: int, b: int, c: Optional[int]): + randomness = rand_fp() - a = rlc_store.to_rlc(a_bytes) - b = rlc_store.to_rlc(b_bytes) - c = ( - rlc_store.to_rlc(c_bytes) - if c_bytes is not None - else (rlc_store.add(a, b) if opcode == Opcode.ADD else rlc_store.sub(a, b))[0] - ) + c = RLC(c, randomness) if c is not None else RLC((a + b if opcode == Opcode.ADD else a - b) % 2 ** 256, randomness) + a = RLC(a, randomness) + b = RLC(b, randomness) - block = Block() - bytecode = Bytecode(f"7f{b_bytes.hex()}7f{a_bytes.hex()}{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().add(a, b) if opcode == Opcode.ADD else Bytecode().sub(a, b) + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(Block().table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1022, a, 0, 0), @@ -53,7 +48,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -62,7 +57,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=66, stack_pointer=1022, gas_left=3, @@ -73,7 +68,7 @@ def test_add(opcode: Opcode, a_bytes: bytes, b_bytes: bytes, c_bytes: Optional[b call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=67, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_begin_tx.py b/tests/evm/test_begin_tx.py index 9908a97a6..ac199750a 100644 --- a/tests/evm/test_begin_tx.py +++ b/tests/evm/test_begin_tx.py @@ -11,61 +11,100 @@ CallContextFieldTag, Block, Transaction, + Account, Bytecode, ) -from zkevm_specs.util import RLCStore, rand_address, rand_range +from zkevm_specs.util import rand_fp, rand_address, rand_range, RLC +from zkevm_specs.util.hash import EMPTY_CODE_HASH + +RETURN_BYTECODE = Bytecode().return_(0, 0) +REVERT_BYTECODE = Bytecode().revert(0, 0) + +CALLEE_ADDRESS = 0xFF +CALLEE_WITH_NOTHING = Account(address=CALLEE_ADDRESS) +CALLEE_WITH_RETURN_BYTECODE = Account(address=CALLEE_ADDRESS, code=RETURN_BYTECODE) +CALLEE_WITH_REVERT_BYTECODE = Account(address=CALLEE_ADDRESS, code=REVERT_BYTECODE) TESTING_DATA = ( - # Transfer 1 ether, successfully - (Transaction(caller_address=0xFE, callee_address=0xFF, value=int(1e18)), True), - # Transfer 1 ether, tx reverts - (Transaction(caller_address=0xFE, callee_address=0xFF, value=int(1e18)), False), + # Transfer 1 ether to EOA, successfully + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_NOTHING, + True, + ), + # Transfer 1 ether to contract, successfully + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), + # Transfer 1 ether to contract, tx reverts + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, value=int(1e18)), + CALLEE_WITH_REVERT_BYTECODE, + False, + ), # Transfer random ether, successfully - (Transaction(caller_address=rand_address(), callee_address=rand_address(), value=rand_range(1e20)), True), + ( + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, value=rand_range(1e20)), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), # Transfer nothing with random gas_price, successfully ( - Transaction(caller_address=rand_address(), callee_address=rand_address(), gas_price=rand_range(42857142857143)), + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, gas_price=rand_range(42857142857143)), + CALLEE_WITH_RETURN_BYTECODE, True, ), # Transfer random ether, tx reverts - (Transaction(caller_address=rand_address(), callee_address=rand_address(), value=rand_range(1e20)), False), + ( + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, value=rand_range(1e20)), + CALLEE_WITH_REVERT_BYTECODE, + False, + ), # Transfer nothing with random gas_price, tx reverts ( - Transaction(caller_address=rand_address(), callee_address=rand_address(), gas_price=rand_range(42857142857143)), + Transaction(caller_address=rand_address(), callee_address=CALLEE_ADDRESS, gas_price=rand_range(42857142857143)), + CALLEE_WITH_REVERT_BYTECODE, False, ), # Transfer nothing with some calldata - (Transaction(caller_address=0xFE, callee_address=0xFF, gas=21080, call_data=bytes([1, 2, 3, 4, 0, 0, 0, 0])), True), + ( + Transaction( + caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=21080, call_data=bytes([1, 2, 3, 4, 0, 0, 0, 0]) + ), + CALLEE_WITH_RETURN_BYTECODE, + True, + ), ) -@pytest.mark.parametrize("tx, result", TESTING_DATA) -def test_begin_tx(tx: Transaction, result: bool): - rlc_store = RLCStore() +@pytest.mark.parametrize("tx, callee, result", TESTING_DATA) +def test_begin_tx(tx: Transaction, callee: Account, result: bool): + randomness = rand_fp() - block = Block() - caller_balance_prev = rlc_store.to_rlc(int(1e20), 32) - callee_balance_prev = rlc_store.to_rlc(0, 32) - caller_balance = rlc_store.to_rlc(int(1e20) - (tx.value + tx.gas * tx.gas_price), 32) - callee_balance = rlc_store.to_rlc(tx.value, 32) + rw_counter_end_of_reversion = 23 + caller_balance_prev = int(1e20) + callee_balance_prev = callee.balance + caller_balance = caller_balance_prev - (tx.value + tx.gas * tx.gas_price) + callee_balance = callee_balance_prev + tx.value - bytecode = Bytecode("00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode_hash = RLC(callee.code_hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), - tx_table=set(tx.table_assignments(rlc_store)), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + block_table=set(Block().table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(callee.code.table_assignments(randomness)), rw_table=set( [ - (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, 1, 0, 0), + (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), ( 2, RW.Read, RWTableTag.CallContext, 1, - CallContextFieldTag.RWCounterEndOfReversion, - 0 if result else 20, + CallContextFieldTag.RwCounterEndOfReversion, + 0 if result else rw_counter_end_of_reversion, 0, 0, ), @@ -79,8 +118,8 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, - caller_balance, - caller_balance_prev, + RLC(caller_balance, randomness), + RLC(caller_balance_prev, randomness), 0, ), ( @@ -89,8 +128,8 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.Account, tx.callee_address, AccountFieldTag.Balance, - callee_balance, - callee_balance_prev, + RLC(callee_balance, randomness), + RLC(callee_balance_prev, randomness), 0, ), ( @@ -114,7 +153,7 @@ def test_begin_tx(tx: Transaction, result: bool): RWTableTag.CallContext, 1, CallContextFieldTag.Value, - rlc_store.to_rlc(tx.value, 32), + RLC(tx.value, randomness), 0, 0, ), @@ -125,23 +164,23 @@ def test_begin_tx(tx: Transaction, result: bool): if result else [ ( - 19, + rw_counter_end_of_reversion - 1, RW.Write, RWTableTag.Account, tx.callee_address, AccountFieldTag.Balance, - callee_balance_prev, - callee_balance, + RLC(callee_balance_prev, randomness), + RLC(callee_balance, randomness), 0, ), ( - 20, + rw_counter_end_of_reversion, RW.Write, RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, - caller_balance_prev, - caller_balance, + RLC(caller_balance_prev, randomness), + RLC(caller_balance, randomness), 0, ), ] @@ -150,21 +189,20 @@ def test_begin_tx(tx: Transaction, result: bool): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( execution_state=ExecutionState.BeginTx, rw_counter=1, - call_id=1, ), StepState( - execution_state=ExecutionState.STOP if result else ExecutionState.REVERT, + execution_state=ExecutionState.EndTx if callee.code_hash() == EMPTY_CODE_HASH else ExecutionState.PUSH, rw_counter=17, call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=0, diff --git a/tests/evm/test_caller.py b/tests/evm/test_caller.py index 9668f3af0..0e5c31ef9 100644 --- a/tests/evm/test_caller.py +++ b/tests/evm/test_caller.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,33 @@ CallContextFieldTag, Bytecode, ) -from zkevm_specs.util import RLCStore, U160 +from zkevm_specs.util import rand_address, rand_fp, RLC, U160 -TESTING_DATA = ((Opcode.CALLER, 0x030201),) +TESTING_DATA = (0x030201, rand_address()) -@pytest.mark.parametrize("opcode, address", TESTING_DATA) -def test_caller(opcode: Opcode, address: U160): - rlc_store = RLCStore() +@pytest.mark.parametrize("caller", TESTING_DATA) +def test_caller(caller: U160): + randomness = rand_fp() - caller_rlc = rlc_store.to_rlc(address.to_bytes(20, "little")) + bytecode = Bytecode().caller() + bytecode_hash = RLC(bytecode.hash(), randomness) - bytecode = Bytecode(f"{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) tables = Tables( block_table=set(), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, caller_rlc, 0, 0), - (10, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, address, 0, 0), + (9, RW.Write, RWTableTag.Stack, 1, 1023, RLC(caller, randomness, 20), 0, 0), + (10, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, caller, 0, 0), ] ), ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -47,7 +45,7 @@ def test_caller(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=2, @@ -58,7 +56,7 @@ def test_caller(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_coinbase.py b/tests/evm/test_coinbase.py index e82f43995..c63506fd0 100644 --- a/tests/evm/test_coinbase.py +++ b/tests/evm/test_coinbase.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,34 @@ Block, Bytecode, ) -from zkevm_specs.util import RLCStore, U160 +from zkevm_specs.util import rand_address, rand_fp, RLC, U160 -TESTING_DATA = ((Opcode.COINBASE, 0x030201),) +TESTING_DATA = (0x030201, rand_address()) -@pytest.mark.parametrize("opcode, address", TESTING_DATA) -def test_coinbase(opcode: Opcode, address: U160): - rlc_store = RLCStore() +@pytest.mark.parametrize("coinbase", TESTING_DATA) +def test_coinbase(coinbase: U160): + randomness = rand_fp() - coinbase_rlc = rlc_store.to_rlc(address.to_bytes(20, "little")) + block = Block(coinbase=coinbase) + + bytecode = Bytecode().coinbase() + bytecode_hash = RLC(bytecode.hash(), randomness) - bytecode = Bytecode(f"{opcode.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) - block = Block(address) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, coinbase_rlc, 0, 0), + (9, RW.Write, RWTableTag.Stack, 1, 1023, RLC(coinbase, randomness, 20), 0, 0), ] ), ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -47,7 +46,7 @@ def test_coinbase(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=2, @@ -58,7 +57,7 @@ def test_coinbase(opcode: Opcode, address: U160): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_end_block.py b/tests/evm/test_end_block.py new file mode 100644 index 000000000..21873f85b --- /dev/null +++ b/tests/evm/test_end_block.py @@ -0,0 +1,57 @@ +import pytest +from itertools import chain + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + RWTableTag, + RW, + CallContextFieldTag, + Block, + Transaction, +) +from zkevm_specs.util import rand_fp + +TESTING_DATA = (False, True) + + +@pytest.mark.parametrize("is_last_step", TESTING_DATA) +def test_end_block(is_last_step: bool): + randomness = rand_fp() + + tx = Transaction() + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(), + rw_table=set( + chain( + # dummy read/write for counting + [(i, *7 * [0]) for i in range(22)], + [(22, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0)] + if is_last_step + else [], + ) + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.EndBlock, + rw_counter=22, + call_id=1, + ), + StepState( + execution_state=ExecutionState.EndBlock, + rw_counter=22, + call_id=1, + ), + ], + end_with_last_step=is_last_step, + ) diff --git a/tests/evm/test_end_tx.py b/tests/evm/test_end_tx.py new file mode 100644 index 000000000..b69725937 --- /dev/null +++ b/tests/evm/test_end_tx.py @@ -0,0 +1,113 @@ +import pytest + +from zkevm_specs.evm import ( + ExecutionState, + StepState, + verify_steps, + Tables, + RWTableTag, + RW, + AccountFieldTag, + CallContextFieldTag, + Block, + Transaction, +) +from zkevm_specs.util import rand_fp, RLC, EMPTY_CODE_HASH, MAX_REFUND_QUOTIENT_OF_GAS_USED + +CALLEE_ADDRESS = 0xFF + +TESTING_DATA = ( + # Tx with non-capped refund + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=27000, gas_price=int(2e9)), + 994, + 4800, + False, + ), + # Tx with capped refund + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=65000, gas_price=int(2e9)), + 3952, + 38400, + False, + ), + # Last tx + ( + Transaction(caller_address=0xFE, callee_address=CALLEE_ADDRESS, gas=21000, gas_price=int(2e9)), + 0, + 0, + True, + ), +) + + +@pytest.mark.parametrize("tx, gas_left, refund, is_last_tx", TESTING_DATA) +def test_end_tx(tx: Transaction, gas_left: int, refund: int, is_last_tx: bool): + randomness = rand_fp() + + block = Block() + effective_refund = min(refund, (tx.gas - gas_left) // MAX_REFUND_QUOTIENT_OF_GAS_USED) + caller_balance_prev = int(1e18) - (tx.value + tx.gas * tx.gas_price) + caller_balance = caller_balance_prev + (gas_left + effective_refund) * tx.gas_price + coinbase_balance_prev = 0 + coinbase_balance = coinbase_balance_prev + (tx.gas - gas_left) * (tx.gas_price - block.base_fee) + + tables = Tables( + block_table=set(block.table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(), + rw_table=set( + [ + (17, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, tx.id, 0, 0), + (18, RW.Read, RWTableTag.TxRefund, tx.id, refund, refund, 0, 0), + ( + 19, + RW.Write, + RWTableTag.Account, + tx.caller_address, + AccountFieldTag.Balance, + RLC(caller_balance, randomness), + RLC(caller_balance_prev, randomness), + 0, + ), + ( + 20, + RW.Write, + RWTableTag.Account, + block.coinbase, + AccountFieldTag.Balance, + RLC(coinbase_balance, randomness), + RLC(coinbase_balance_prev, randomness), + 0, + ), + ] + + ( + [] + if is_last_tx + else [(21, RW.Read, RWTableTag.CallContext, 22, CallContextFieldTag.TxId, tx.id + 1, 0, 0)] + ) + ), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=[ + StepState( + execution_state=ExecutionState.EndTx, + rw_counter=17, + call_id=1, + is_root=True, + is_create=False, + code_source=RLC(EMPTY_CODE_HASH, randomness), + program_counter=0, + stack_pointer=1024, + gas_left=gas_left, + state_write_counter=2, + ), + StepState( + execution_state=ExecutionState.EndBlock if is_last_tx else ExecutionState.BeginTx, + rw_counter=22 - is_last_tx, + ), + ], + ) diff --git a/tests/evm/test_jump.py b/tests/evm/test_jump.py index 8e6cb1769..06ea99e65 100644 --- a/tests/evm/test_jump.py +++ b/tests/evm/test_jump.py @@ -1,6 +1,5 @@ import pytest -from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -12,7 +11,7 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util import rand_fp, RLC TESTING_DATA = ((Opcode.JUMP, bytes([7])),) @@ -20,19 +19,19 @@ @pytest.mark.parametrize("opcode, dest_bytes", TESTING_DATA) def test_jump(opcode: Opcode, dest_bytes: bytes): - rlc_store = RLCStore() - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 # PUSH1 80 PUSH1 40 PUSH1 07 JUMP JUMPDEST STOP - bytecode = Bytecode(f"6080604060{dest_bytes.hex()}565b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(0x40).push1(dest_bytes).jump().jumpdest().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -41,7 +40,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -50,7 +49,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=8, @@ -61,7 +60,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=int.from_bytes(dest_bytes, "little"), stack_pointer=1022, gas_left=0, diff --git a/tests/evm/test_jumpi.py b/tests/evm/test_jumpi.py index 221796fee..6c7105c7b 100644 --- a/tests/evm/test_jumpi.py +++ b/tests/evm/test_jumpi.py @@ -1,6 +1,5 @@ import pytest -from typing import Optional from zkevm_specs.evm import ( ExecutionState, StepState, @@ -12,7 +11,7 @@ Block, Bytecode, ) -from zkevm_specs.util import hex_to_word, rand_bytes, RLCStore +from zkevm_specs.util import rand_fp, RLC TESTING_DATA = ((Opcode.JUMPI, bytes([40]), bytes([7])),) @@ -20,20 +19,20 @@ @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA) def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - rlc_store = RLCStore() - cond = rlc_store.to_rlc(bytes(reversed(cond_bytes))) - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + cond = RLC(bytes(reversed(cond_bytes)), randomness) + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 because the condition (40) is nonzero. # PUSH1 80 PUSH1 40 PUSH1 07 JUMPI JUMPDEST STOP - bytecode = Bytecode(f"6080604060{dest_bytes.hex()}575b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(0x40).push1(dest_bytes).jumpi().jumpdest().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -43,7 +42,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -52,7 +51,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=10, @@ -63,7 +62,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=int.from_bytes(dest_bytes, "little"), stack_pointer=1023, gas_left=0, @@ -77,20 +76,20 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA_ZERO_COND) def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - rlc_store = RLCStore() - cond = rlc_store.to_rlc(bytes(reversed(cond_bytes))) - dest = rlc_store.to_rlc(bytes(reversed(dest_bytes))) + randomness = rand_fp() + cond = RLC(bytes(reversed(cond_bytes)), randomness) + dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() # Jumps to PC=7 because the condition (0) is zero. # PUSH1 80 PUSH1 0 PUSH1 08 JUMPI STOP - bytecode = Bytecode(f"6080600060{dest_bytes.hex()}575b00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push1(0x80).push1(cond_bytes).push1(dest_bytes).jumpi().stop() + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(block.table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (9, RW.Read, RWTableTag.Stack, 1, 1021, dest, 0, 0), @@ -100,7 +99,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -109,7 +108,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=6, stack_pointer=1021, gas_left=10, @@ -120,7 +119,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=7, stack_pointer=1023, gas_left=0, diff --git a/tests/evm/test_push.py b/tests/evm/test_push.py index d663a2aaf..252b11a17 100644 --- a/tests/evm/test_push.py +++ b/tests/evm/test_push.py @@ -3,7 +3,6 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, RWTableTag, @@ -11,34 +10,33 @@ Block, Bytecode, ) -from zkevm_specs.util import rand_bytes, RLCStore +from zkevm_specs.util import rand_bytes, rand_fp, RLC TESTING_DATA = tuple( [ - (Opcode.PUSH1, bytes([1])), - (Opcode.PUSH2, bytes([2, 1])), - (Opcode.PUSH31, bytes([i for i in range(31, 0, -1)])), - (Opcode.PUSH32, bytes([i for i in range(32, 0, -1)])), + (bytes([1])), + (bytes([2, 1])), + (bytes([i for i in range(31, 0, -1)])), + (bytes([i for i in range(32, 0, -1)])), ] - + [(Opcode(Opcode.PUSH1 + i), rand_bytes(i + 1)) for i in range(32)] + + [(rand_bytes(i + 1)) for i in range(32)] ) -@pytest.mark.parametrize("opcode, value_be_bytes", TESTING_DATA) -def test_push(opcode: Opcode, value_be_bytes: bytes): - rlc_store = RLCStore() +@pytest.mark.parametrize("value_be_bytes", TESTING_DATA) +def test_push(value_be_bytes: bytes): + randomness = rand_fp() - value = rlc_store.to_rlc(bytes(reversed(value_be_bytes))) + value = RLC(bytes(reversed(value_be_bytes)), randomness) - block = Block() - bytecode = Bytecode(f"{opcode.hex()}{value_be_bytes.hex()}00") - bytecode_hash = rlc_store.to_rlc(bytecode.hash, 32) + bytecode = Bytecode().push(value_be_bytes, n_bytes=len(value_be_bytes)) + bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(block.table_assignments(rlc_store)), + block_table=set(Block().table_assignments(randomness)), tx_table=set(), - bytecode_table=set(bytecode.table_assignments(rlc_store)), + bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( [ (8, RW.Write, RWTableTag.Stack, 1, 1023, value, 0, 0), @@ -47,7 +45,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): ) verify_steps( - rlc_store=rlc_store, + randomness=randomness, tables=tables, steps=[ StepState( @@ -56,7 +54,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=0, stack_pointer=1024, gas_left=3, @@ -67,7 +65,7 @@ def test_push(opcode: Opcode, value_be_bytes: bytes): call_id=1, is_root=True, is_create=False, - opcode_source=bytecode_hash, + code_source=bytecode_hash, program_counter=1 + len(value_be_bytes), stack_pointer=1023, gas_left=0, diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index cc7308e7f..99c0a0739 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -3,22 +3,22 @@ import traceback from copy import deepcopy from zkevm_specs.evm import Bytecode -from zkevm_specs.util import RLCStore +from zkevm_specs.util import RLC, rand_fp # Unroll the bytecode -def unroll(bytecode, rlc_store): - return UnrolledBytecode(bytecode, list(Bytecode(bytecode).table_assignments(rlc_store))) +def unroll(bytecode, randomness): + return UnrolledBytecode(bytecode, list(Bytecode(bytecode).table_assignments(randomness))) # Verify the bytecode circuit with the given data -def verify(k, bytecodes, rlc_store, success): +def verify(k, bytecodes, randomness, success): push_table = assign_push_table() - keccak_table = assign_keccak_table(map(lambda v: v.bytes, bytecodes), rlc_store) - rows = assign_bytecode_circuit(k, bytecodes, rlc_store) + keccak_table = assign_keccak_table(map(lambda v: v.bytes, bytecodes), randomness) + rows = assign_bytecode_circuit(k, bytecodes, randomness) try: for (idx, row) in enumerate(rows): prev_row = rows[(idx - 1) % len(rows)] - check_bytecode_row(row, prev_row, push_table, keccak_table, rlc_store.randomness) + check_bytecode_row(row, prev_row, push_table, keccak_table, randomness) ok = True except AssertionError as e: if success: @@ -28,7 +28,7 @@ def verify(k, bytecodes, rlc_store, success): k = 10 -rlc_store = RLCStore() +randomness = rand_fp() def test_bytecode_unrolling(): @@ -48,110 +48,104 @@ def test_bytecode_unrolling(): for _ in range(n): rows.append((0, len(rows), data_byte, False)) # Set the hash of the complete bytecode in the rows - hash = rlc_store.to_rlc(keccak256(bytes(bytecode)), 32) + hash = RLC(bytes(reversed(keccak256(bytes(bytecode)))), randomness) for i in range(len(rows)): rows[i] = (hash, rows[i][1], rows[i][2], rows[i][3]) # Unroll the bytecode - unrolled = unroll(bytes(bytecode), rlc_store) + unrolled = unroll(bytes(bytecode), randomness) # Check if the bytecode was unrolled correctly assert UnrolledBytecode(bytes(bytecode), rows) == unrolled # Verify the unrolling in the circuit - verify(k, [unrolled], rlc_store, True) + verify(k, [unrolled], randomness, True) def test_bytecode_empty(): - bytecodes = [ - unroll(bytes([]), rlc_store), - ] - verify(k, bytecodes, rlc_store, True) + bytecodes = [unroll(bytes([]), randomness)] + verify(k, bytecodes, randomness, True) def test_bytecode_full(): - bytecodes = [ - unroll(bytes([7] * 2 ** k), rlc_store), - ] - verify(k, bytecodes, rlc_store, True) + bytecodes = [unroll(bytes([7] * 2 ** k), randomness)] + verify(k, bytecodes, randomness, True) def test_bytecode_incomplete(): - bytecodes = [ - unroll(bytes([7] * (2 ** k + 1)), rlc_store), - ] - verify(k, bytecodes, rlc_store, False) + bytecodes = [unroll(bytes([7] * (2 ** k + 1)), randomness)] + verify(k, bytecodes, randomness, False) def test_bytecode_multiple(): bytecodes = [ - unroll(bytes([]), rlc_store), - unroll(bytes([Opcode.PUSH32]), rlc_store), - unroll(bytes([Opcode.PUSH32, Opcode.ADD]), rlc_store), - unroll(bytes([Opcode.ADD, Opcode.PUSH32]), rlc_store), - unroll(bytes([Opcode.ADD, Opcode.PUSH32, Opcode.ADD]), rlc_store), + unroll(bytes([]), randomness), + unroll(bytes([Opcode.PUSH32]), randomness), + unroll(bytes([Opcode.PUSH32, Opcode.ADD]), randomness), + unroll(bytes([Opcode.ADD, Opcode.PUSH32]), randomness), + unroll(bytes([Opcode.ADD, Opcode.PUSH32, Opcode.ADD]), randomness), ] - verify(k, bytecodes, rlc_store, True) + verify(k, bytecodes, randomness, True) def test_bytecode_invalid_hash_data(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Change the hash on the first position invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Change the hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): invalid.rows[idx] = (1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_index(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Start the index at 1 invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (row[0] + 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[idx] = (row[0].value + 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) row = unrolled.rows[-1] - invalid.rows[-1] = (row[0] - 1, row[1], row[2], row[3]) - verify(k, [invalid], rlc_store, False) + invalid.rows[-1] = (row[0].value - 1, row[1], row[2], row[3]) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_byte_data(): - unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), rlc_store) - verify(k, [unrolled], rlc_store, True) + unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) + verify(k, [unrolled], randomness, True) # Change the first byte invalid = deepcopy(unrolled) row = unrolled.rows[0] invalid.rows[0] = (row[0], row[1], row[2], 9) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Change a byte on another position invalid = deepcopy(unrolled) row = unrolled.rows[5] invalid.rows[5] = (row[0], row[1], row[2], 6) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Set a byte value out of range invalid = deepcopy(unrolled) row = unrolled.rows[3] invalid.rows[3] = (row[0], row[1], row[2], 256) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) def test_bytecode_invalid_is_code(): @@ -167,24 +161,24 @@ def test_bytecode_invalid_is_code(): Opcode.PUSH6, ] ), - rlc_store, + randomness, ) - verify(k, [unrolled], rlc_store, True) + verify(k, [unrolled], randomness, True) # Mark the 3rd byte as code (is push data from the first PUSH1) invalid = deepcopy(unrolled) row = unrolled.rows[2] invalid.rows[2] = (row[0], row[1], 1, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Mark the 4rd byte as data (is code) invalid = deepcopy(unrolled) row = unrolled.rows[3] invalid.rows[3] = (row[0], row[1], 0, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False) # Mark the 7th byte as code (is data for the PUSH7) invalid = deepcopy(unrolled) row = unrolled.rows[6] invalid.rows[6] = (row[0], row[1], 1, row[3]) - verify(k, [invalid], rlc_store, False) + verify(k, [invalid], randomness, False)