diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5a0007af9..22f778391 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,5 +27,7 @@ jobs: make install - name: Lint run: make lint + - name: Type check + run: make type - name: Test with pytest run: make test diff --git a/Makefile b/Makefile index d90c6de50..32c551cdd 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ lint: ## Check whether the code is formated correctly black . --check mdformat specs/ --number --check +type: ## Check the typing of the Python code + mypy src + test: ## Run tests pytest diff --git a/setup.cfg b/setup.cfg index e77608300..81a802224 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ test = lint = black >= 22.1.0 mdformat >= 0.7.13 + mypy >= 0.931 diff --git a/src/zkevm_specs/__init__.py b/src/zkevm_specs/__init__.py index 668f399ba..c302c1579 100644 --- a/src/zkevm_specs/__init__.py +++ b/src/zkevm_specs/__init__.py @@ -1,3 +1,4 @@ +from . import bytecode from . import encoding from . import evm from . import opcode diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 8b7ef60b4..1fee53213 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -1,7 +1,7 @@ -from typing import Sequence, Union, Tuple, Set +from typing import Sequence, Union, Tuple, Set, NamedTuple from collections import namedtuple from .util import keccak256, FQ, RLC -from .evm.opcode import get_push_size +from .evm import get_push_size, BytecodeTableRow from .encoding import U8, U256, is_circuit_code # Row in the circuit @@ -10,7 +10,9 @@ "q_first q_last hash index byte is_code push_data_left hash_rlc hash_length byte_push_size is_final padding", ) # Unrolled bytecode -UnrolledBytecode = namedtuple("UnrolledBytecode", "bytes rows") +class UnrolledBytecode(NamedTuple): + bytes: bytes + rows: Sequence[BytecodeTableRow] @is_circuit_code @@ -24,7 +26,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code @@ -107,21 +109,21 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando for idx, row in enumerate(bytecode.rows): # Track which byte is an opcode and which is push data is_code = push_data_left == 0 - byte_push_size = get_push_size(row[2]) + byte_push_size = get_push_size(row.byte) push_data_left = byte_push_size if is_code else push_data_left - 1 # Add the byte to the accumulator - hash_rlc = hash_rlc * randomness + row[2] + hash_rlc = hash_rlc * randomness + row.byte # Set the data for this row rows.append( Row( offset == 0, offset == last_row_offset, - row[0], - row[1], - row[2], - row[3], + row.bytecode_hash, + row.index, + row.byte, + row.is_code, push_data_left, hash_rlc, len(bytecode.bytes), @@ -178,10 +180,10 @@ def assign_push_table(): # Generate keccak table -def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int): +def assign_keccak_table(bytecodes: Sequence[bytes], randomness: FQ): keccak_table = [] for bytecode in bytecodes: hash = RLC(bytes(reversed(keccak256(bytecode))), randomness) rlc = RLC(bytes(reversed(bytecode)), randomness, len(bytecode)) - keccak_table.append((rlc, len(bytecode), hash)) + keccak_table.append((rlc.expr(), len(bytecode), hash.expr())) return _convert_table(keccak_table) diff --git a/src/zkevm_specs/encoding/lookup.py b/src/zkevm_specs/encoding/lookup.py index 774b892dd..808ab0d58 100644 --- a/src/zkevm_specs/encoding/lookup.py +++ b/src/zkevm_specs/encoding/lookup.py @@ -2,7 +2,7 @@ class LookupTable: - columns: Tuple[str] + columns: Tuple[str, ...] rows: Set[Tuple[int, ...]] def __init__(self, columns: Sequence[str]) -> None: diff --git a/src/zkevm_specs/encoding/utils.py b/src/zkevm_specs/encoding/utils.py index 9c0db0cd2..1b64996ee 100644 --- a/src/zkevm_specs/encoding/utils.py +++ b/src/zkevm_specs/encoding/utils.py @@ -1,8 +1,8 @@ -from typing import Sequence, Tuple, List +from typing import Sequence, Tuple from .typing import U8, U256, U64 -def is_circuit_code(func) -> object: +def is_circuit_code(func): """ A no-op decorator just to mark the function """ @@ -15,19 +15,19 @@ def wrapper(*args, **kargs): def u256_to_u8s(x: U256) -> Tuple[U8, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 8 * i) & 0xFF for i in range(32)) + return tuple(U8((x >> 8 * i) & 0xFF) for i in range(32)) def u256_to_u64s(x: U256) -> Tuple[U64, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF for i in range(4)) + return tuple(U64((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)) def u8s_to_u256(xs: Sequence[U8]) -> U256: assert len(xs) == 32 for u8 in xs: assert 0 <= u8 <= 255 - return sum(x * (2 ** (8 * i)) for i, x in enumerate(xs)) + return U256(sum(x * (2 ** (8 * i)) for i, x in enumerate(xs))) # [u8;32]->[u64;4] @@ -37,5 +37,5 @@ def u8s_to_u64s(xs: Sequence[U8]) -> Tuple[U64, ...]: A = [u64_0] * 4 # A = A3A2A1A0 for i in range(4): for j in range(8): - A[i] += U64(xs[j + 8 * i] * (2 ** (8 * j))) + A[i] += xs[j + 8 * i] * (2 ** (8 * j)) return tuple(A) diff --git a/src/zkevm_specs/evm/__init__.py b/src/zkevm_specs/evm/__init__.py index 525240e62..ab5b82f32 100644 --- a/src/zkevm_specs/evm/__init__.py +++ b/src/zkevm_specs/evm/__init__.py @@ -6,4 +6,4 @@ from .step import * from .table import * from .typing import * -from . import util +from .util import * diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index 58de3305c..d9a2a1774 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,6 +1,6 @@ -from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH, FQ, RLC, cast_expr from ..execution_state import ExecutionState -from ..instruction import Instruction, Transition +from ..instruction import Instruction, ReversionInfo, Transition from ..precompiled import PrecompiledAddress from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag @@ -9,28 +9,23 @@ def begin_tx(instruction: Instruction): call_id = instruction.curr.rw_counter tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=call_id) - rw_counter_end_of_reversion = instruction.call_context_lookup( - CallContextFieldTag.RwCounterEndOfReversion, call_id=call_id - ) - is_persistent = instruction.call_context_lookup( - CallContextFieldTag.IsPersistent, call_id=call_id - ) + reversion_info = instruction.reversion_info(call_id=call_id) if instruction.is_first_step: - instruction.constrain_equal(instruction.curr.rw_counter, 1) - instruction.constrain_equal(tx_id, 1) + instruction.constrain_equal(instruction.curr.rw_counter, FQ(1)) + instruction.constrain_equal(tx_id, FQ(1)) tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress) tx_callee_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CalleeAddress) tx_is_create = instruction.tx_context_lookup(tx_id, TxContextFieldTag.IsCreate) - tx_value = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Value) + tx_value = cast_expr(instruction.tx_context_lookup(tx_id, TxContextFieldTag.Value), RLC) tx_call_data_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength) # Verify nonce tx_nonce = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Nonce) nonce, nonce_prev = instruction.account_write(tx_caller_address, AccountFieldTag.Nonce) instruction.constrain_equal(tx_nonce, nonce_prev) - instruction.constrain_equal(nonce, nonce_prev + 1) + instruction.constrain_equal(nonce, nonce_prev.expr() + 1) # TODO: Implement EIP 1559 (currently it supports legacy transaction format) # Calculate gas fee @@ -42,15 +37,19 @@ def begin_tx(instruction: Instruction): # TODO: Handle gas cost of tx level access list (EIP 2930) tx_call_data_gas_cost = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataGasCost) gas_left = ( - tx_gas + tx_gas.expr() - (GAS_COST_CREATION_TX if tx_is_create == 1 else GAS_COST_TX) - - tx_call_data_gas_cost + - tx_call_data_gas_cost.expr() ) instruction.constrain_gas_left_not_underflow(gas_left) # Prepare access list of caller and callee - instruction.constrain_equal(instruction.add_account_to_access_list(tx_id, tx_caller_address), 1) - instruction.constrain_equal(instruction.add_account_to_access_list(tx_id, tx_callee_address), 1) + instruction.constrain_equal( + instruction.add_account_to_access_list(tx_id, tx_caller_address), FQ(1) + ) + instruction.constrain_equal( + instruction.add_account_to_access_list(tx_id, tx_callee_address), FQ(1) + ) # Verify transfer instruction.transfer_with_gas_fee( @@ -58,8 +57,7 @@ def begin_tx(instruction: Instruction): tx_callee_address, tx_value, gas_fee, - is_persistent, - rw_counter_end_of_reversion, + reversion_info, ) if tx_is_create == 1: @@ -79,16 +77,16 @@ def begin_tx(instruction: Instruction): # - TxId is checked from previous step or constraint to 1 if is_first_step # - IsSuccess, IsPersistent will be verified in the end of tx for (tag, value) in [ - (CallContextFieldTag.Depth, 1), + (CallContextFieldTag.Depth, FQ(1)), (CallContextFieldTag.CallerAddress, tx_caller_address), (CallContextFieldTag.CalleeAddress, tx_callee_address), - (CallContextFieldTag.CallDataOffset, 0), + (CallContextFieldTag.CallDataOffset, FQ(0)), (CallContextFieldTag.CallDataLength, tx_call_data_length), (CallContextFieldTag.Value, tx_value), - (CallContextFieldTag.IsStatic, False), - (CallContextFieldTag.LastCalleeId, 0), - (CallContextFieldTag.LastCalleeReturnDataOffset, 0), - (CallContextFieldTag.LastCalleeReturnDataLength, 0), + (CallContextFieldTag.IsStatic, FQ(False)), + (CallContextFieldTag.LastCalleeId, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataOffset, FQ(0)), + (CallContextFieldTag.LastCalleeReturnDataLength, FQ(0)), ]: instruction.constrain_equal( instruction.call_context_lookup(tag, call_id=call_id), value @@ -104,9 +102,13 @@ def begin_tx(instruction: Instruction): state_write_counter=Transition.to(2), ) + assert instruction.next is not None + # Constrain either: # - is_empty_code and is_to_end_tx # - (not is_empty_code) and (not is_to_end_tx) - is_empty_code = instruction.is_equal(code_hash, instruction.int_to_rlc(EMPTY_CODE_HASH, 32)) + is_empty_code = instruction.is_equal( + code_hash, RLC(EMPTY_CODE_HASH, instruction.randomness) + ) is_to_end_tx = instruction.is_equal(instruction.next.execution_state, ExecutionState.EndTx) instruction.constrain_equal(is_empty_code + is_to_end_tx, 2 * is_empty_code * is_to_end_tx) diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 55d0b3b20..1f769911d 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -1,3 +1,4 @@ +from ...util.param import N_BYTES_ACCOUNT_ADDRESS from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag from ..opcode import Opcode @@ -6,18 +7,14 @@ def coinbase(instruction: Instruction): opcode = instruction.opcode_lookup(True) instruction.constrain_equal(opcode, Opcode.COINBASE) - address = instruction.stack_push() + # in real circuit also check address raw data is 160 bit length (20 bytes) # check block table for coinbase address instruction.constrain_equal( - address, - instruction.int_to_rlc( - instruction.block_context_lookup(BlockContextFieldTag.Coinbase), - # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 - # byte RLC constraint in all places. See: - # https://github.com/appliedzkp/zkevm-specs/issues/101 - 20, - ), + instruction.block_context_lookup(BlockContextFieldTag.Coinbase), + # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 byte RLC constraint in + # all places. See: https://github.com/appliedzkp/zkevm-specs/issues/101 + instruction.rlc_to_fq_exact(instruction.stack_push(), N_BYTES_ACCOUNT_ADDRESS), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/block_timestamp.py b/src/zkevm_specs/evm/execution/block_timestamp.py index c503e2750..3f4c16202 100644 --- a/src/zkevm_specs/evm/execution/block_timestamp.py +++ b/src/zkevm_specs/evm/execution/block_timestamp.py @@ -6,11 +6,11 @@ def timestamp(instruction: Instruction): opcode = instruction.opcode_lookup(True) instruction.constrain_equal(opcode, Opcode.TIMESTAMP) - timestamp = instruction.stack_push() + # check block table for timestamp instruction.constrain_equal( - timestamp, - instruction.int_to_rlc(instruction.block_context_lookup(BlockContextFieldTag.Timestamp), 8), + instruction.block_context_lookup(BlockContextFieldTag.Timestamp), + instruction.rlc_to_fq_exact(instruction.stack_push(), 8), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/calldatacopy.py b/src/zkevm_specs/evm/execution/calldatacopy.py index 599b3dc58..268e8535e 100644 --- a/src/zkevm_specs/evm/execution/calldatacopy.py +++ b/src/zkevm_specs/evm/execution/calldatacopy.py @@ -1,24 +1,24 @@ -from ...util import N_BYTES_MEMORY_ADDRESS, FQ +from ...util import N_BYTES_MEMORY_ADDRESS, FQ, Expression from ..execution_state import ExecutionState from ..instruction import Instruction, Transition -from ..table import RW, FixedTableTag, RWTableTag, CallContextFieldTag, TxContextFieldTag +from ..table import RW, CallContextFieldTag, TxContextFieldTag def calldatacopy(instruction: Instruction): opcode = instruction.opcode_lookup(True) - memory_offset = instruction.stack_pop() - data_offset = instruction.stack_pop() - length = instruction.stack_pop() + memory_offset_word = instruction.stack_pop() + data_offset_word = instruction.stack_pop() + length_word = instruction.stack_pop() # convert rlc to FQ - memory_offset, length = instruction.memory_offset_and_length(memory_offset, length) - data_offset = instruction.rlc_to_fq_exact(data_offset, N_BYTES_MEMORY_ADDRESS) + memory_offset, length = instruction.memory_offset_and_length(memory_offset_word, length_word) + data_offset = instruction.rlc_to_fq_exact(data_offset_word, N_BYTES_MEMORY_ADDRESS) tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId, RW.Read) if instruction.curr.is_root: call_data_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength) - call_data_offset = FQ.zero() + call_data_offset: Expression = FQ.zero() else: call_data_length = instruction.call_context_lookup( CallContextFieldTag.CallDataLength, RW.Read @@ -34,12 +34,15 @@ def calldatacopy(instruction: Instruction): # When length != 0, constrain the state in the next execution state CopyToMemory if not instruction.is_zero(length): + assert instruction.next is not None instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) next_aux = instruction.next.aux_data instruction.constrain_equal(next_aux.src_addr, data_offset + call_data_offset) instruction.constrain_equal(next_aux.dst_addr, memory_offset) - instruction.constrain_equal(next_aux.src_addr_end, call_data_length + call_data_offset) - instruction.constrain_equal(next_aux.from_tx, instruction.curr.is_root) + instruction.constrain_equal( + next_aux.src_addr_end, call_data_length.expr() + call_data_offset + ) + instruction.constrain_equal(next_aux.from_tx, FQ(instruction.curr.is_root)) instruction.constrain_equal(next_aux.tx_id, tx_id) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/calldataload.py b/src/zkevm_specs/evm/execution/calldataload.py index c2361cbc3..1c4751310 100644 --- a/src/zkevm_specs/evm/execution/calldataload.py +++ b/src/zkevm_specs/evm/execution/calldataload.py @@ -1,8 +1,8 @@ +from ...util import FQ, RLC, Expression, N_BYTES_WORD from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import RW, CallContextFieldTag, TxContextFieldTag from ..util import BufferReaderGadget -from ...util.param import N_BYTES_WORD def calldataload(instruction: Instruction): @@ -16,37 +16,36 @@ def calldataload(instruction: Instruction): if instruction.curr.is_root: calldata_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength) - calldata_offset = 0 + calldata_offset: Expression = FQ(0) else: calldata_length = instruction.call_context_lookup(CallContextFieldTag.CallDataLength) calldata_offset = instruction.call_context_lookup(CallContextFieldTag.CallDataOffset) caller_id = instruction.call_context_lookup(CallContextFieldTag.CallerId) src_addr = offset + calldata_offset - src_addr_end = calldata_length + calldata_offset + src_addr_end = calldata_length.expr() + calldata_offset.expr() buffer_reader = BufferReaderGadget( - instruction, N_BYTES_WORD, src_addr, src_addr_end, N_BYTES_WORD + instruction, N_BYTES_WORD, src_addr, src_addr_end, FQ(N_BYTES_WORD) ) calldata_word = [] for idx in range(N_BYTES_WORD): - if buffer_reader.read_flag(idx): + if buffer_reader.read_flag(idx) == FQ(1): if instruction.curr.is_root: tx_byte = instruction.tx_calldata_lookup(tx_id, src_addr + idx) buffer_reader.constrain_byte(idx, tx_byte) - calldata_word.append(int(tx_byte)) + calldata_word.append(tx_byte.expr().n) else: mem_byte = instruction.memory_lookup(RW.Read, src_addr + idx, caller_id) buffer_reader.constrain_byte(idx, mem_byte) - calldata_word.append(int(mem_byte)) + calldata_word.append(mem_byte.expr().n) else: - buffer_reader.constrain_byte(idx, 0) calldata_word.append(0) instruction.constrain_equal( instruction.stack_push(), - instruction.bytes_to_rlc(bytes(calldata_word)), + RLC(bytes(calldata_word), instruction.randomness), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/calldatasize.py b/src/zkevm_specs/evm/execution/calldatasize.py index 2495577a0..21b5bc7d1 100644 --- a/src/zkevm_specs/evm/execution/calldatasize.py +++ b/src/zkevm_specs/evm/execution/calldatasize.py @@ -1,23 +1,20 @@ +from ...util import N_BYTES_MEMORY_ADDRESS from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode -from ...util.param import N_BYTES_MEMORY_ADDRESS def calldatasize(instruction: Instruction): opcode = instruction.opcode_lookup(True) + instruction.constrain_equal(opcode, Opcode.CALLDATASIZE) # check [rw_table, call_context] table for call data length and compare # against stack top after push. instruction.constrain_equal( - instruction.int_to_rlc( - instruction.call_context_lookup(CallContextFieldTag.CallDataLength), - # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 - # byte RLC constraint in all places. See: - # https://github.com/appliedzkp/zkevm-specs/issues/101 - N_BYTES_MEMORY_ADDRESS, - ), - instruction.stack_push(), + instruction.call_context_lookup(CallContextFieldTag.CallDataLength), + # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 byte RLC constraint in + # all places. See: https://github.com/appliedzkp/zkevm-specs/issues/101 + instruction.rlc_to_fq_exact(instruction.stack_push(), N_BYTES_MEMORY_ADDRESS), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 58bf046ea..970cdec98 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -1,7 +1,7 @@ +from ...util import N_BYTES_ACCOUNT_ADDRESS from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode -from ...util.param import N_BYTES_ACCOUNT_ADDRESS def caller(instruction: Instruction): @@ -11,14 +11,10 @@ def caller(instruction: Instruction): # check [rw_table, call_context] table for caller address and compare with # stack top after push instruction.constrain_equal( - instruction.int_to_rlc( - instruction.call_context_lookup(CallContextFieldTag.CallerAddress), - # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 - # byte RLC constraint in all places. See: - # https://github.com/appliedzkp/zkevm-specs/issues/101 - N_BYTES_ACCOUNT_ADDRESS, - ), - instruction.stack_push(), + instruction.call_context_lookup(CallContextFieldTag.CallerAddress), + # NOTE: We can replace this with N_BYTES_WORD if we reuse the 32 byte RLC constraint in + # all places. See: https://github.com/appliedzkp/zkevm-specs/issues/101 + instruction.rlc_to_fq_exact(instruction.stack_push(), N_BYTES_ACCOUNT_ADDRESS), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/end_block.py b/src/zkevm_specs/evm/execution/end_block.py index 4ba98a1e9..f41070f7b 100644 --- a/src/zkevm_specs/evm/execution/end_block.py +++ b/src/zkevm_specs/evm/execution/end_block.py @@ -1,3 +1,4 @@ +from ...util import FQ from ..instruction import Instruction, Transition from ..table import CallContextFieldTag @@ -12,7 +13,7 @@ def end_block(instruction: Instruction): total_tx = instruction.call_context_lookup(CallContextFieldTag.TxId) instruction.constrain_equal( total_tx, - max([tx_id for tx_id, *_ in instruction.tables.tx_table]), + FQ(max([row.tx_id.expr().n for row in instruction.tables.tx_table])), ) # Verify rw_counter counts to identical rw amount in rw_table to ensure @@ -20,7 +21,7 @@ def end_block(instruction: Instruction): total_rw = instruction.curr.rw_counter + 1 # extra 1 from the tx_id lookup instruction.constrain_equal( total_rw, - len(instruction.tables.rw_table), + FQ(len(instruction.tables.rw_table)), ) else: # Propagate rw_counter and call_id all the way down diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py index e9d295b32..d61a5d218 100644 --- a/src/zkevm_specs/evm/execution/end_tx.py +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -1,4 +1,4 @@ -from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED +from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED, FQ, RLC, cast_expr from ..execution_state import ExecutionState from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag, CallContextFieldTag, TxContextFieldTag @@ -11,7 +11,7 @@ def end_tx(instruction: Instruction): tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas) gas_used = tx_gas - instruction.curr.gas_left max_refund, _ = instruction.constant_divmod( - gas_used, MAX_REFUND_QUOTIENT_OF_GAS_USED, N_BYTES_GAS + gas_used, FQ(MAX_REFUND_QUOTIENT_OF_GAS_USED), N_BYTES_GAS ) refund = instruction.tx_refund_read(tx_id) effective_refund = instruction.min(max_refund, refund, 8) @@ -26,28 +26,28 @@ def end_tx(instruction: Instruction): instruction.add_balance(tx_caller_address, [value]) # Add gas_used * effective_tip to coinbase's balance - base_fee = instruction.block_context_lookup(BlockContextFieldTag.BaseFee) + base_fee = cast_expr(instruction.block_context_lookup(BlockContextFieldTag.BaseFee), RLC) effective_tip, _ = instruction.sub_word(tx_gas_price, base_fee) reward, carry = instruction.mul_word_by_u64(effective_tip, gas_used) instruction.constrain_zero(carry) coinbase = instruction.block_context_lookup(BlockContextFieldTag.Coinbase) instruction.add_balance(coinbase, [reward]) - # Go to next transaction + assert instruction.next is not None + + # When to next transaction if instruction.next.execution_state == ExecutionState.BeginTx: # Check next tx_id is increased by 1 instruction.constrain_equal( instruction.call_context_lookup( CallContextFieldTag.TxId, call_id=instruction.next.rw_counter ), - tx_id + 1, + tx_id.expr() + 1, ) - # Do step state transition for rw_counter instruction.constrain_step_state_transition(rw_counter=Transition.delta(5)) - # Go to end of block - elif instruction.next.execution_state == ExecutionState.EndBlock: + + # When to end of block + if instruction.next.execution_state == ExecutionState.EndBlock: # Do step state transition for rw_counter instruction.constrain_step_state_transition(rw_counter=Transition.delta(4)) - else: - raise ValueError("Unreacheable") diff --git a/src/zkevm_specs/evm/execution/gas.py b/src/zkevm_specs/evm/execution/gas.py index d0a0faea5..e74e5f60b 100644 --- a/src/zkevm_specs/evm/execution/gas.py +++ b/src/zkevm_specs/evm/execution/gas.py @@ -1,6 +1,6 @@ +from ...util import N_BYTES_GAS from ..instruction import Instruction, Transition from ..opcode import Opcode -from ..table import CallContextFieldTag, TxContextFieldTag def gas(instruction: Instruction): @@ -8,8 +8,7 @@ def gas(instruction: Instruction): instruction.constrain_equal(opcode, Opcode.GAS) # fetch gas from rw table and consider only the lower 8 bytes (uint64) - gas = instruction.rlc_to_le_bytes(instruction.stack_push()) - gas = int.from_bytes(gas[0:8], "little") + gas = instruction.rlc_to_fq_exact(instruction.stack_push(), N_BYTES_GAS) instruction.constrain_equal( gas, diff --git a/src/zkevm_specs/evm/execution/jump.py b/src/zkevm_specs/evm/execution/jump.py index 8e46d6d1f..e68ea02ad 100644 --- a/src/zkevm_specs/evm/execution/jump.py +++ b/src/zkevm_specs/evm/execution/jump.py @@ -15,7 +15,6 @@ def jump(instruction: Instruction): dest_value = instruction.rlc_to_fq_exact(dest, N_BYTES_PROGRAM_COUNTER) # Verify `dest` is code within byte code table - # assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True) instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True)) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/jumpi.py b/src/zkevm_specs/evm/execution/jumpi.py index c485b0b74..d680c6404 100644 --- a/src/zkevm_specs/evm/execution/jumpi.py +++ b/src/zkevm_specs/evm/execution/jumpi.py @@ -1,4 +1,4 @@ -from ...util.param import N_BYTES_PROGRAM_COUNTER +from ...util import FQ, N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -14,7 +14,7 @@ def jumpi(instruction: Instruction): # check `cond` is zero or not if instruction.is_zero(cond): - pc_diff = 1 + pc_diff = FQ(1) else: # Get `dest` raw value in max 8 bytes dest_value = instruction.rlc_to_fq_exact(dest, N_BYTES_PROGRAM_COUNTER) diff --git a/src/zkevm_specs/evm/execution/memory_copy.py b/src/zkevm_specs/evm/execution/memory_copy.py index cabb93b71..7ce8044e0 100644 --- a/src/zkevm_specs/evm/execution/memory_copy.py +++ b/src/zkevm_specs/evm/execution/memory_copy.py @@ -1,8 +1,8 @@ -from ...util import FQ, N_BYTES_MEMORY_SIZE +from ...util import N_BYTES_MEMORY_SIZE, FQ, Expression from ..execution_state import ExecutionState from ..instruction import Instruction, Transition from ..step import CopyToMemoryAuxData -from ..table import RW, TxContextFieldTag +from ..table import RW from ..util import BufferReaderGadget @@ -17,19 +17,15 @@ def copy_to_memory(instruction: Instruction): instruction, MAX_COPY_BYTES, aux.src_addr, aux.src_addr_end, aux.bytes_left ) - data = [] - rw_counter_delta = 0 for i in range(MAX_COPY_BYTES): - if not buffer_reader.read_flag(i): - byte = FQ.zero() + if buffer_reader.read_flag(i) == 0: + byte: Expression = FQ(0) elif aux.from_tx == 1: - byte = instruction.tx_context_lookup( - aux.tx_id, TxContextFieldTag.CallData, aux.src_addr + i - ) + byte = instruction.tx_calldata_lookup(aux.tx_id, aux.src_addr + i) else: byte = instruction.memory_lookup(RW.Read, aux.src_addr + i) buffer_reader.constrain_byte(i, byte) - if buffer_reader.has_data(i): + if buffer_reader.has_data(i) == 1: instruction.constrain_equal(byte, instruction.memory_lookup(RW.Write, aux.dst_addr + i)) copied_bytes = buffer_reader.num_bytes() @@ -38,9 +34,12 @@ def copy_to_memory(instruction: Instruction): instruction.constrain_zero((1 - lt) * (1 - finished)) if finished == 0: - instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) + assert instruction.next is not None next_aux = instruction.next.aux_data - assert next_aux is not None and isinstance(next_aux, CopyToMemoryAuxData) + + assert isinstance(next_aux, CopyToMemoryAuxData) + + instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) instruction.constrain_equal(next_aux.src_addr, aux.src_addr + copied_bytes) instruction.constrain_equal(next_aux.dst_addr, aux.dst_addr + copied_bytes) instruction.constrain_equal(next_aux.bytes_left + copied_bytes, aux.bytes_left) diff --git a/src/zkevm_specs/evm/execution/origin.py b/src/zkevm_specs/evm/execution/origin.py index 47155b854..8295b189a 100644 --- a/src/zkevm_specs/evm/execution/origin.py +++ b/src/zkevm_specs/evm/execution/origin.py @@ -1,7 +1,7 @@ +from ...util import N_BYTES_ACCOUNT_ADDRESS from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import CallContextFieldTag, TxContextFieldTag -from ...util.param import N_BYTES_ACCOUNT_ADDRESS def origin(instruction: Instruction): @@ -11,11 +11,8 @@ def origin(instruction: Instruction): instruction.constrain_equal(opcode, Opcode.ORIGIN) instruction.constrain_equal( - instruction.int_to_rlc( - instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress), - N_BYTES_ACCOUNT_ADDRESS, - ), - instruction.stack_push(), + instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress), + instruction.rlc_to_fq_exact(instruction.stack_push(), N_BYTES_ACCOUNT_ADDRESS), ) instruction.step_state_transition_in_same_context( diff --git a/src/zkevm_specs/evm/execution/push.py b/src/zkevm_specs/evm/execution/push.py index cb5602684..043c5adff 100644 --- a/src/zkevm_specs/evm/execution/push.py +++ b/src/zkevm_specs/evm/execution/push.py @@ -1,3 +1,4 @@ +from ...util import FQ from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -8,17 +9,16 @@ def push(instruction: Instruction): num_additional_pushed = num_pushed - 1 value = instruction.stack_push() - value_le_bytes = instruction.rlc_to_le_bytes(value) selectors = instruction.continuous_selectors(num_additional_pushed, 31) for idx in range(32): index = instruction.curr.program_counter + num_pushed - idx - if idx == 0 or selectors[idx - 1]: + if idx == 0 or selectors[idx - 1] == 1: instruction.constrain_equal( - value_le_bytes[idx], instruction.opcode_lookup_at(index, False) + FQ(value.le_bytes[idx]), instruction.opcode_lookup_at(index, False) ) else: - instruction.constrain_zero(value_le_bytes[idx]) + instruction.constrain_zero(FQ(value.le_bytes[idx])) instruction.step_state_transition_in_same_context( opcode, diff --git a/src/zkevm_specs/evm/execution/slt_sgt.py b/src/zkevm_specs/evm/execution/slt_sgt.py index 752ef1bc1..9aad5a79d 100644 --- a/src/zkevm_specs/evm/execution/slt_sgt.py +++ b/src/zkevm_specs/evm/execution/slt_sgt.py @@ -1,5 +1,4 @@ -from typing import Sequence, Tuple - +from ...util import FQ from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -14,13 +13,13 @@ def scmp(instruction: Instruction): c = instruction.stack_push() # swap a and b if the opcode is SGT - aa = b if is_sgt else a - bb = a if is_sgt else b + aa = b if is_sgt == 1 else a + bb = a if is_sgt == 1 else b # decode RLC to bytes for a and b - a8s = instruction.rlc_to_le_bytes(aa) - b8s = instruction.rlc_to_le_bytes(bb) - c8s = instruction.rlc_to_le_bytes(c) + a8s = aa.le_bytes + b8s = bb.le_bytes + c8s = c.le_bytes a_lo = instruction.bytes_to_fq(a8s[:16]) a_hi = instruction.bytes_to_fq(a8s[16:]) @@ -29,17 +28,19 @@ def scmp(instruction: Instruction): assert c8s[31] == 0 cc = instruction.bytes_to_fq(c8s[:31]) - a_lt_b_lo, a_eq_b_lo = instruction.compare(a_lo, b_lo, 16) + a_lt_b_lo, _ = instruction.compare(a_lo, b_lo, 16) a_lt_b_hi, a_eq_b_hi = instruction.compare(a_hi, b_hi, 16) - a_lt_b = instruction.select(a_lt_b_hi, 1, instruction.select(a_eq_b_hi * a_lt_b_lo, 1, 0)) + a_lt_b = instruction.select( + a_lt_b_hi, FQ(1), instruction.select(a_eq_b_hi * a_lt_b_lo, FQ(1), FQ(0)) + ) # a < 0 and b >= 0 => a < b == true if a8s[31] >= 128 and b8s[31] < 128: - instruction.constrain_equal(cc, 1) + instruction.constrain_equal(cc, FQ(1)) # b < 0 and a >= 0 => a < b == false elif b8s[31] >= 128 and a8s[31] < 128: - instruction.constrain_equal(cc, 0) + instruction.constrain_equal(cc, FQ(0)) # (a < 0 and b < 0) or (a >= 0 and b >= 0) else: instruction.constrain_equal(cc, a_lt_b) diff --git a/src/zkevm_specs/evm/execution/storage.py b/src/zkevm_specs/evm/execution/storage.py index ae738ac87..32ea5317b 100644 --- a/src/zkevm_specs/evm/execution/storage.py +++ b/src/zkevm_specs/evm/execution/storage.py @@ -1,7 +1,5 @@ -from ..instruction import Instruction, Transition -from ..opcode import Opcode -from ..table import CallContextFieldTag, TxContextFieldTag -from ...util.param import ( +from ...util import ( + FQ, COLD_SLOAD_COST, WARM_STORAGE_READ_COST, SLOAD_GAS, @@ -9,6 +7,9 @@ SSTORE_RESET_GAS, SSTORE_CLEARS_SCHEDULE, ) +from ..instruction import Instruction, Transition +from ..opcode import Opcode +from ..table import CallContextFieldTag def sload(instruction: Instruction): @@ -16,10 +17,7 @@ def sload(instruction: Instruction): instruction.constrain_equal(opcode, Opcode.SLOAD) tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) - rw_counter_end_of_reversion = instruction.call_context_lookup( - CallContextFieldTag.RwCounterEndOfReversion - ) - is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent) + reversion_info = instruction.reversion_info() callee_address = instruction.call_context_lookup(CallContextFieldTag.CalleeAddress) storage_key = instruction.stack_pop() @@ -29,11 +27,14 @@ def sload(instruction: Instruction): instruction.stack_push(), ) - is_warm_new, is_warm = instruction.add_account_storage_to_access_list_with_reversion( - tx_id, callee_address, storage_key, is_persistent, rw_counter_end_of_reversion + is_cold = instruction.add_account_storage_to_access_list( + tx_id, + callee_address, + storage_key, + reversion_info, ) - dynamic_gas_cost = instruction.select(is_warm, WARM_STORAGE_READ_COST, COLD_SLOAD_COST) + dynamic_gas_cost = instruction.select(is_cold, FQ(COLD_SLOAD_COST), FQ(WARM_STORAGE_READ_COST)) instruction.step_state_transition_in_same_context( opcode, @@ -50,26 +51,27 @@ def sstore(instruction: Instruction): instruction.constrain_equal(opcode, Opcode.SSTORE) tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) - rw_counter_end_of_reversion = instruction.call_context_lookup( - CallContextFieldTag.RwCounterEndOfReversion - ) - is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent) + reversion_info = instruction.reversion_info() callee_address = instruction.call_context_lookup(CallContextFieldTag.CalleeAddress) storage_key = instruction.stack_pop() storage_value = instruction.stack_pop() - value, value_prev, original_value = instruction.account_storage_write_with_reversion( - callee_address, storage_key, tx_id, is_persistent, rw_counter_end_of_reversion + value, value_prev, original_value = instruction.account_storage_write( + callee_address, + storage_key, + tx_id, + reversion_info, ) instruction.constrain_equal(storage_value, value) - is_warm_new, is_warm = instruction.add_account_storage_to_access_list_with_reversion( - tx_id, callee_address, storage_key, is_persistent, rw_counter_end_of_reversion + is_cold = instruction.add_account_storage_to_access_list( + tx_id, + callee_address, + storage_key, + reversion_info, ) - gas_refund, gas_refund_prev = instruction.tx_refund_write_with_reversion( - tx_id, is_persistent, rw_counter_end_of_reversion - ) + gas_refund, gas_refund_prev = instruction.tx_refund_write(tx_id, reversion_info) # original_value, value_prev, value all are different; original_value!=0 nz_allne_case_refund = instruction.select( @@ -113,17 +115,18 @@ def sstore(instruction: Instruction): instruction.constrain_equal(gas_refund, gas_refund_new) + eq_prev = instruction.is_equal(value_prev, value) + prev_ne_original = 1 - instruction.is_equal(value_prev, original_value) warm_case_gas = instruction.select( - instruction.is_equal(value_prev, value) - or (not instruction.is_equal(original_value, value_prev)), - SLOAD_GAS, + eq_prev + prev_ne_original - eq_prev * prev_ne_original, + FQ(SLOAD_GAS), instruction.select( instruction.is_zero(original_value), - SSTORE_SET_GAS, - SSTORE_RESET_GAS, + FQ(SSTORE_SET_GAS), + FQ(SSTORE_RESET_GAS), ), ) - dynamic_gas_cost = instruction.select(is_warm, warm_case_gas, warm_case_gas + COLD_SLOAD_COST) + dynamic_gas_cost = instruction.select(is_cold, warm_case_gas + COLD_SLOAD_COST, warm_case_gas) instruction.step_state_transition_in_same_context( opcode, diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index f57b80769..d6b41b97b 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -1,12 +1,13 @@ from enum import IntEnum, auto from typing import Sequence, Tuple, Union +from ..util import FQ from .opcode import ( Opcode, invalid_opcodes, - state_write_opcodes, - stack_underflow_pairs, stack_overflow_pairs, + stack_underflow_pairs, + state_write_opcodes, ) @@ -140,6 +141,9 @@ class ExecutionState(IntEnum): # TODO: Precompile success and error cases + def expr(self) -> FQ: + return FQ(self) + def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: if self == ExecutionState.STOP: return [Opcode.STOP] @@ -378,17 +382,17 @@ def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: return state_write_opcodes() return [] - def halts(self): + def halts(self) -> bool: return self.halts_in_success() or self.halts_in_exception() or self == ExecutionState.REVERT - def halts_in_success(self): + def halts_in_success(self) -> bool: return self in [ ExecutionState.STOP, ExecutionState.RETURN, ExecutionState.SELFDESTRUCT, ] - def halts_in_exception(self): + def halts_in_exception(self) -> bool: return self in [ ExecutionState.ErrorInvalidOpcode, ExecutionState.ErrorStack, diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 7f40870b0..2374da577 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -1,13 +1,14 @@ from __future__ import annotations from enum import IntEnum, auto -from typing import Optional, Sequence, Tuple, Union, Mapping +from typing import Optional, Sequence, Tuple, Union from ..util import ( - Array4, - Array10, FQ, IntOrFQ, RLC, + Expression, + ExpressionImpl, + cast_expr, MAX_N_BYTES, N_BYTES_MEMORY_ADDRESS, N_BYTES_MEMORY_SIZE, @@ -23,6 +24,8 @@ AccountFieldTag, BlockContextFieldTag, CallContextFieldTag, + FixedTableRow, + RWTableRow, Tables, FixedTableTag, TxContextFieldTag, @@ -44,27 +47,51 @@ class TransitionKind(IntEnum): class Transition: kind: TransitionKind - value: Optional[int] + value: Union[int, FQ, RLC] - def __init__(self, kind: TransitionKind, value: Optional[int] = None) -> None: + def __init__(self, kind: TransitionKind, value: Union[int, FQ, RLC] = 0) -> None: self.kind = kind self.value = value + @staticmethod def same() -> Transition: return Transition(TransitionKind.Same) - def delta(delta: int): + @staticmethod + def delta(delta: Union[int, FQ, RLC]): return Transition(TransitionKind.Delta, delta) - def to(to: int): + @staticmethod + def to(to: Union[int, FQ, RLC]): return Transition(TransitionKind.To, to) +class ReversionInfo: + rw_counter_end_of_reversion: FQ + is_persistent: FQ + state_write_counter: FQ + + def __init__( + self, + rw_counter_end_of_reversion: Expression, + is_persistent: Expression, + state_write_counter: Expression, + ) -> None: + self.rw_counter_end_of_reversion = rw_counter_end_of_reversion.expr() + self.is_persistent = is_persistent.expr() + self.state_write_counter = state_write_counter.expr() + + def rw_counter(self) -> FQ: + rw_counter = self.rw_counter_end_of_reversion - self.state_write_counter + self.state_write_counter += 1 + return rw_counter + + class Instruction: randomness: FQ tables: Tables curr: StepState - next: StepState + next: Optional[StepState] # meta information is_first_step: bool @@ -74,14 +101,13 @@ class Instruction: rw_counter_offset: int = 0 program_counter_offset: int = 0 stack_pointer_offset: int = 0 - state_write_counter_offset: int = 0 def __init__( self, randomness: FQ, tables: Tables, curr: StepState, - next: StepState, + next: Optional[StepState], is_first_step: bool, is_last_step: bool, ) -> None: @@ -92,20 +118,20 @@ def __init__( self.is_first_step = is_first_step self.is_last_step = is_last_step - def constrain_zero(self, value: FQ): - assert value == 0, ConstraintUnsatFailure(f"Expected value to be 0, but got {value}") + def constrain_zero(self, value: Expression): + assert value.expr() == 0, ConstraintUnsatFailure(f"Expected value to be 0, but got {value}") - def constrain_equal(self, lhs: FQ, rhs: FQ): - assert lhs == rhs, ConstraintUnsatFailure( + def constrain_equal(self, lhs: Expression, rhs: Expression): + assert lhs.expr() == rhs.expr(), ConstraintUnsatFailure( f"Expected values to be equal, but got {lhs} and {rhs}" ) - def constrain_bool(self, num: FQ): - assert num.n in [0, 1], ConstraintUnsatFailure( + def constrain_bool(self, num: Expression): + assert num.expr() in [0, 1], ConstraintUnsatFailure( f"Expected value to be a bool, but got {num}" ) - def constrain_gas_left_not_underflow(self, gas_left: FQ): + def constrain_gas_left_not_underflow(self, gas_left: Expression): self.range_check(gas_left, N_BYTES_GAS) def constrain_execution_state_transition(self): @@ -149,20 +175,28 @@ def constrain_step_state_transition(self, **kwargs: Transition): for key, transition in kwargs.items(): curr, next = getattr(self.curr, key), getattr(self.next, key) + if isinstance(curr, int): + curr = FQ(curr) + if isinstance(next, int): + next = FQ(next) if transition.kind == TransitionKind.Same: - assert next == curr, ConstraintUnsatFailure( + assert next.expr() == curr.expr(), ConstraintUnsatFailure( f"State {key} should be same as {curr}, but got {next}" ) elif transition.kind == TransitionKind.Delta: - assert next == curr + transition.value, ConstraintUnsatFailure( - f"State {key} should transit to {curr + transition.value}, but got {next}" + if isinstance(transition.value, int): + transition.value = FQ(transition.value) + assert next.expr() == curr.expr() + transition.value.expr(), ConstraintUnsatFailure( + f"State {key} should transit to {curr} + {transition.value}, but got {next}" ) elif transition.kind == TransitionKind.To: - assert next == transition.value, ConstraintUnsatFailure( + if isinstance(transition.value, int): + transition.value = FQ(transition.value) + assert next.expr() == transition.value.expr(), ConstraintUnsatFailure( f"State {key} should transit to {transition.value}, but got {next}" ) else: - raise ValueError("unreacheable") + raise ValueError("Unreacheable") def step_state_transition_to_new_context( self, @@ -190,17 +224,17 @@ def step_state_transition_to_new_context( def step_state_transition_in_same_context( self, - opcode: int, + opcode: Expression, rw_counter: Transition = Transition.same(), program_counter: Transition = Transition.same(), stack_pointer: Transition = Transition.same(), memory_size: Transition = Transition.same(), state_write_counter: Transition = Transition.same(), - dynamic_gas_cost: int = 0, + dynamic_gas_cost: IntOrFQ = 0, ): self.responsible_opcode_lookup(opcode) - gas_cost = Opcode(opcode).constant_gas_cost() + dynamic_gas_cost + gas_cost = FQ(Opcode(opcode.expr().n).constant_gas_cost() + dynamic_gas_cost) self.constrain_gas_left_not_underflow(self.curr.gas_left - gas_cost) self.constrain_step_state_transition( @@ -217,50 +251,47 @@ def step_state_transition_in_same_context( code_source=Transition.same(), ) - def sum(self, values: Sequence[FQ]) -> FQ: - return sum(values) + def sum(self, values: Sequence[IntOrFQ]) -> FQ: + return FQ(sum(values)) - def is_zero(self, value: Union[FQ, RLC]) -> bool: - return value == 0 + def is_zero(self, value: Expression) -> FQ: + return FQ(value.expr() == 0) - def is_equal(self, lhs: Union[FQ, RLC], rhs: Union[FQ, RLC]) -> bool: - if isinstance(lhs, RLC): - lhs = lhs.value - if isinstance(rhs, RLC): - rhs = rhs.value - return self.is_zero(lhs - rhs) + def is_equal(self, lhs: Expression, rhs: Expression) -> FQ: + return self.is_zero(lhs.expr() - rhs.expr()) - def continuous_selectors(self, t: IntOrFQ, n: int) -> Sequence[bool]: - t = t.n if isinstance(t, FQ) else t - return [i < t for i in range(n)] + def continuous_selectors(self, value: Expression, n: int) -> Sequence[FQ]: + return [FQ(i < value.expr().n) for i in range(n)] - def select(self, condition: bool, when_true: FQ, when_false: FQ) -> FQ: - return when_true if condition else when_false + def select( + self, condition: FQ, when_true: ExpressionImpl, when_false: ExpressionImpl + ) -> ExpressionImpl: + assert condition in [0, 1], "Condition of select should be a checked bool" + return when_true if condition == 1 else when_false - def pair_select(self, value: FQ, lhs: FQ, rhs: FQ) -> Tuple[bool, bool]: - return value == lhs, value == rhs + def pair_select(self, value: Expression, lhs: Expression, rhs: Expression) -> Tuple[FQ, FQ]: + return FQ(value.expr() == lhs.expr()), FQ(value.expr() == rhs.expr()) def constant_divmod( - self, numerator: IntOrFQ, denominator: IntOrFQ, n_bytes: int + self, numerator: Expression, denominator: Expression, n_bytes: int ) -> Tuple[FQ, FQ]: - quotient, remainder = divmod(FQ(numerator).n, FQ(denominator).n) - quotient, remainder = FQ(quotient), FQ(remainder) - self.range_check(quotient, n_bytes) - return quotient, remainder + quotient, remainder = divmod(numerator.expr().n, denominator.expr().n) + self.range_check(FQ(quotient), n_bytes) + return FQ(quotient), FQ(remainder) - def compare(self, lhs: FQ, rhs: FQ, n_bytes: int) -> Tuple[bool, bool]: + def compare(self, lhs: Expression, rhs: Expression, n_bytes: int) -> Tuple[FQ, FQ]: assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" - assert lhs.n < 256**n_bytes, f"lhs {lhs} exceeds the range of {n_bytes} bytes" - assert rhs.n < 256**n_bytes, f"rhs {rhs} exceeds the range of {n_bytes} bytes" - return lhs.n < rhs.n, lhs.n == rhs.n + assert lhs.expr().n < 256**n_bytes, f"lhs {lhs} exceeds the range of {n_bytes} bytes" + assert rhs.expr().n < 256**n_bytes, f"rhs {rhs} exceeds the range of {n_bytes} bytes" + return FQ(lhs.expr().n < rhs.expr().n), FQ(lhs.expr().n == rhs.expr().n) - def min(self, lhs: FQ, rhs: FQ, n_bytes: int) -> FQ: + def min(self, lhs: Expression, rhs: Expression, n_bytes: int) -> FQ: lt, _ = self.compare(lhs, rhs, n_bytes) - return self.select(lt, lhs, rhs) + return cast_expr(self.select(lt, lhs, rhs), FQ) - def max(self, lhs: FQ, rhs: FQ, n_bytes: int) -> FQ: + def max(self, lhs: Expression, rhs: Expression, n_bytes: int) -> FQ: lt, _ = self.compare(lhs, rhs, n_bytes) - return self.select(lt, rhs, lhs) + return cast_expr(self.select(lt, rhs, lhs), FQ) def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, FQ]: addends_lo, addends_hi = list(zip(*map(self.word_to_lo_hi, addends))) @@ -269,11 +300,10 @@ def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, FQ]: carry_hi, sum_hi = divmod((self.sum(addends_hi) + carry_lo).n, 1 << 128) sum_bytes = sum_lo.to_bytes(16, "little") + sum_hi.to_bytes(16, "little") - carry_hi = FQ(carry_hi) - return RLC(sum_bytes, self.randomness), carry_hi + return RLC(sum_bytes, self.randomness), FQ(carry_hi) - def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: + def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, FQ]: minuend_lo, minuend_hi = self.word_to_lo_hi(minuend) subtrahend_lo, subtrahend_hi = self.word_to_lo_hi(subtrahend) @@ -284,92 +314,82 @@ def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: diff_bytes = diff_lo.n.to_bytes(16, "little") + diff_hi.n.to_bytes(16, "little") - return RLC(diff_bytes, self.randomness), borrow_hi + return RLC(diff_bytes, self.randomness), FQ(borrow_hi) - def mul_word_by_u64(self, multiplicand: RLC, multiplier: FQ) -> Tuple[RLC, FQ]: + def mul_word_by_u64(self, multiplicand: RLC, multiplier: Expression) -> Tuple[RLC, FQ]: multiplicand_lo, multiplicand_hi = self.word_to_lo_hi(multiplicand) - quotient_lo, product_lo = divmod((multiplicand_lo * multiplier).n, 1 << 128) - quotient_hi, product_hi = divmod((multiplicand_hi * multiplier + quotient_lo).n, 1 << 128) + quotient_lo, product_lo = divmod((multiplicand_lo * multiplier.expr()).n, 1 << 128) + quotient_hi, product_hi = divmod( + (multiplicand_hi * multiplier.expr() + quotient_lo).n, 1 << 128 + ) product_bytes = product_lo.to_bytes(16, "little") + product_hi.to_bytes(16, "little") - quotient_hi = FQ(quotient_hi) - - return RLC(product_bytes, self.randomness), quotient_hi - def rlc_to_le_bytes(self, rlc: RLC) -> bytes: - return rlc.le_bytes + return RLC(product_bytes, self.randomness), FQ(quotient_hi) - def rlc_to_fq_unchecked(self, rlc: RLC, n_bytes: int) -> FQ: - rlc_le_bytes = self.rlc_to_le_bytes(rlc) - return self.bytes_to_fq(rlc_le_bytes[:n_bytes]), self.is_zero( - self.sum(rlc_le_bytes[n_bytes:]) + def rlc_to_fq_unchecked(self, word: RLC, n_bytes: int) -> Tuple[FQ, FQ]: + return self.bytes_to_fq(word.le_bytes[:n_bytes]), self.is_zero( + self.sum(word.le_bytes[n_bytes:]) ) - def rlc_to_fq_exact(self, rlc: RLC, n_bytes: int) -> FQ: - rlc_le_bytes = self.rlc_to_le_bytes(rlc) + def rlc_to_fq_exact(self, word: RLC, n_bytes: int) -> FQ: + if any(word.le_bytes[n_bytes:]): + raise ConstraintUnsatFailure(f"Word {word} has too many bytes to fit {n_bytes} bytes") - if sum(rlc_le_bytes[n_bytes:]) > 0: - raise ConstraintUnsatFailure(f"Value {rlc} has too many bytes to fit {n_bytes} bytes") - - return self.bytes_to_fq(rlc_le_bytes[:n_bytes]) + return self.bytes_to_fq(word.le_bytes[:n_bytes]) def word_to_lo_hi(self, word: RLC) -> Tuple[FQ, FQ]: - word_le_bytes = self.rlc_to_le_bytes(word) - assert len(word_le_bytes) == 32, "Expected word to contain 32 bytes" - return self.bytes_to_fq(word_le_bytes[:16]), self.bytes_to_fq(word_le_bytes[16:]) - - def int_to_rlc(self, value: int, n_bytes: int) -> RLC: - return RLC(value, self.randomness, n_bytes) - - def bytes_to_rlc(self, value: bytes) -> RLC: - return RLC(value, self.randomness, len(value)) - - def bytes_to_int(self, value: bytes) -> int: - assert len(value) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" - return int.from_bytes(value, "little") + assert len(word.le_bytes) == 32, "Expected word to contain 32 bytes" + return self.bytes_to_fq(word.le_bytes[:16]), self.bytes_to_fq(word.le_bytes[16:]) def bytes_to_fq(self, value: bytes) -> FQ: assert len(value) <= MAX_N_BYTES, "Too many bytes to composite an integer in field" return FQ(int.from_bytes(value, "little")) - def range_lookup(self, value: FQ, range: int): - self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), value, 0, 0]) + def range_lookup(self, value: Expression, range: int): + self.fixed_lookup(FixedTableTag.range_table_tag(range), value) - def byte_range_lookup(self, value: FQ): - assert isinstance(value, FQ), f"Expect type FQ, but get type {type(value)}" + def byte_range_lookup(self, value: Expression): self.range_lookup(value, 256) - def range_check(self, value: FQ, n_bytes: int) -> bytes: + def range_check(self, value: Expression, n_bytes: int) -> bytes: assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" - assert isinstance(value, FQ) try: - return value.n.to_bytes(n_bytes, "little") + return value.expr().n.to_bytes(n_bytes, "little") except OverflowError: raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") - def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> Array4: - return self.tables.fixed_lookup([tag] + inputs) + def fixed_lookup( + self, + tag: FixedTableTag, + value0: Expression, + value1: Expression = None, + value2: Expression = None, + ) -> FixedTableRow: + return self.tables.fixed_lookup(FQ(tag), value0, value1, value2) - def block_context_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ.zero()) -> FQ: - return self.tables.block_lookup([tag, index])[2] + def block_context_lookup( + self, field_tag: BlockContextFieldTag, block_number: Expression = FQ(0) + ) -> Expression: + return self.tables.block_lookup(FQ(field_tag), block_number).value - def tx_context_lookup( - self, tx_id: FQ, field_tag: TxContextFieldTag, index: FQ = FQ.zero() - ) -> Union[FQ, RLC]: - return self.tables.tx_lookup([tx_id, field_tag, index])[3] + def tx_context_lookup(self, tx_id: Expression, field_tag: TxContextFieldTag) -> Expression: + return self.tables.tx_lookup(tx_id, FQ(field_tag)).value - def tx_calldata_lookup(self, tx_id: FQ, index: FQ) -> FQ: - return self.tables.tx_lookup([tx_id, TxContextFieldTag.CallData, index])[3] + def tx_calldata_lookup(self, tx_id: Expression, call_data_index: Expression) -> Expression: + return self.tables.tx_lookup(tx_id, FQ(TxContextFieldTag.CallData), call_data_index).value - def bytecode_lookup(self, bytecode_hash: RLC, index: FQ, is_code: FQ) -> FQ: - return self.tables.bytecode_lookup([bytecode_hash, index, Tables._, is_code])[2] + def bytecode_lookup( + self, bytecode_hash: Expression, index: Expression, is_code: bool + ) -> Expression: + return self.tables.bytecode_lookup(bytecode_hash, index, FQ(is_code)).byte - def tx_gas_price(self, tx_id: FQ) -> FQ: - return self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice) + def tx_gas_price(self, tx_id: Expression) -> RLC: + return cast_expr(self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice), RLC) - def responsible_opcode_lookup(self, opcode: int): - self.fixed_lookup(FixedTableTag.ResponsibleOpcode, [self.curr.execution_state, opcode]) + def responsible_opcode_lookup(self, opcode: Expression): + self.fixed_lookup(FixedTableTag.ResponsibleOpcode, FQ(self.curr.execution_state), opcode) def opcode_lookup(self, is_code: bool) -> FQ: index = self.curr.program_counter + self.program_counter_offset @@ -382,364 +402,303 @@ def opcode_lookup_at(self, index: FQ, is_code: bool) -> FQ: "The opcode source when is_root and is_create (root creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.code_source, index, is_code) + return self.bytecode_lookup(self.curr.code_source, index, is_code).expr() def rw_lookup( - self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None - ) -> Array10: + self, + rw: RW, + tag: RWTableTag, + key1: Expression = None, + key2: Expression = None, + key3: Expression = None, + value: Expression = None, + value_prev: Expression = None, + aux0: Expression = None, + aux1: Expression = None, + rw_counter: Expression = None, + ) -> RWTableRow: if rw_counter is None: rw_counter = self.curr.rw_counter + self.rw_counter_offset self.rw_counter_offset += 1 - return self.tables.rw_lookup([rw_counter, rw, tag] + inputs) - - def state_write_only_persistent( - self, - tag: RWTableTag, - inputs: Sequence[int], - is_persistent: bool, - ) -> Array10: - assert tag.write_only_persistent() - - if is_persistent: - return self.rw_lookup(RW.Write, tag, inputs) - return 10 * [None] + return self.tables.rw_lookup( + rw_counter, + FQ(rw), + FQ(tag), + key1, + key2, + key3, + value, + value_prev, + aux0, + aux1, + ) - def state_write_with_reversion( + def state_write( self, tag: RWTableTag, - inputs: Sequence[int], - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Array10: + key1: Expression = None, + key2: Expression = None, + key3: Expression = None, + value: Expression = None, + value_prev: Expression = None, + aux0: Expression = None, + aux1: Expression = None, + reversion_info: ReversionInfo = None, + ) -> RWTableRow: assert tag.write_with_reversion() - row = self.rw_lookup(RW.Write, tag, inputs) - - if state_write_counter is None: - state_write_counter = self.curr.state_write_counter + self.state_write_counter_offset - self.state_write_counter_offset += 1 - - rw_counter = rw_counter_end_of_reversion - state_write_counter - - if not is_persistent: - # Swap value and value_prev - inputs = list(row[3:]) - inputs[-3], inputs[-4] = inputs[-4], inputs[-3] - self.rw_lookup(RW.Write, tag, inputs, rw_counter=rw_counter) + row = self.rw_lookup(RW.Write, tag, key1, key2, key3, value, value_prev, aux0, aux1) + + if reversion_info is not None and reversion_info.is_persistent == 0: + self.tables.rw_lookup( + rw_counter=reversion_info.rw_counter(), + rw=FQ(RW.Write), + tag=FQ(tag), + key1=row.key1, + key2=row.key2, + key3=row.key3, + # Swap value and value_prev + value=row.value_prev, + value_prev=row.value, + aux0=row.aux0, + aux1=row.aux1, + ) return row def call_context_lookup( - self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[int] = None - ) -> FQ: + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Expression = None + ) -> Expression: if call_id is None: call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag])[-4] + return self.rw_lookup(rw, RWTableTag.CallContext, call_id, FQ(field_tag)).value - def stack_pop(self) -> Union[FQ, RLC]: + def reversion_info(self, call_id: Expression = None) -> ReversionInfo: + rw_counter_end_of_reversion = self.call_context_lookup( + CallContextFieldTag.RwCounterEndOfReversion, call_id=call_id + ) + is_persistent = self.call_context_lookup(CallContextFieldTag.IsPersistent, call_id=call_id) + return ReversionInfo( + rw_counter_end_of_reversion, is_persistent, self.curr.state_write_counter + ) + + def stack_pop(self) -> RLC: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 - return self.stack_lookup(False, stack_pointer_offset) + return self.stack_lookup(RW.Read, FQ(stack_pointer_offset)) - def stack_push(self) -> Union[FQ, RLC]: + def stack_push(self) -> RLC: self.stack_pointer_offset -= 1 - return self.stack_lookup(True, self.stack_pointer_offset) + return self.stack_lookup(RW.Write, FQ(self.stack_pointer_offset)) - def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[FQ, RLC]: + def stack_lookup(self, rw: RW, stack_pointer_offset: Expression) -> RLC: stack_pointer = self.curr.stack_pointer + stack_pointer_offset - return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer])[-4] + return cast_expr( + self.rw_lookup(rw, RWTableTag.Stack, self.curr.call_id, stack_pointer).value, RLC + ) - def memory_write(self, memory_address: int, call_id: Optional[int] = None) -> FQ: + def memory_write(self, memory_address: Expression, call_id: Expression = None) -> FQ: return self.memory_lookup(RW.Write, memory_address, call_id) - def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = None) -> FQ: + def memory_lookup(self, rw: RW, memory_address: Expression, call_id: Expression = None) -> FQ: if call_id is None: call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address])[-4] + return cast_expr(self.rw_lookup(rw, RWTableTag.Memory, call_id, memory_address).value, FQ) - def tx_refund_read(self, tx_id) -> FQ: - row = self.rw_lookup(RW.Read, RWTableTag.TxRefund, [tx_id]) - return row[-4] + def tx_refund_read(self, tx_id: Expression) -> FQ: + return cast_expr(self.rw_lookup(RW.Read, RWTableTag.TxRefund, tx_id).value, FQ) - def tx_refund_write_with_reversion( + def tx_refund_write( self, - tx_id: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + tx_id: Expression, + reversion_info: ReversionInfo = None, ) -> Tuple[FQ, FQ]: - row = self.state_write_with_reversion( + row = self.state_write( RWTableTag.TxRefund, - [tx_id], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + tx_id, + reversion_info=reversion_info, ) - return row[-4], row[-3] - - def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> FQ: - row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) - return row[-4] - - def account_write( - self, - account_address: int, - account_field_tag: AccountFieldTag, - ) -> Tuple[FQ, FQ]: - row = self.rw_lookup( - RW.Write, - RWTableTag.Account, - [account_address, account_field_tag], + return cast_expr(row.value, FQ), cast_expr(row.value_prev, FQ) + + def account_read(self, account_address: Expression, account_field_tag: AccountFieldTag) -> RLC: + return cast_expr( + self.rw_lookup( + RW.Read, RWTableTag.Account, account_address, FQ(account_field_tag) + ).value, + RLC, ) - return row[-4], row[-3] - def account_write_with_reversion( + def account_write( self, - account_address: int, + account_address: Expression, account_field_tag: AccountFieldTag, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: - row = self.state_write_with_reversion( + reversion_info: ReversionInfo = None, + ) -> Tuple[Expression, Expression]: + row = self.state_write( RWTableTag.Account, - [account_address, account_field_tag], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + account_address, + FQ(account_field_tag), + reversion_info=reversion_info, ) - return row[-4], row[-3] - - def add_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: - balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([balance_prev, *values]) - self.constrain_equal(balance, result) - self.constrain_zero(carry) - return balance, balance_prev + return row.value, row.value_prev - def add_balance_with_reversion( + def add_balance( self, - account_address: int, - values: Sequence[int], - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: - balance, balance_prev = self.account_write_with_reversion( - account_address, - AccountFieldTag.Balance, - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + account_address: Expression, + values: Sequence[RLC], + reversion_info: ReversionInfo = None, + ) -> Tuple[RLC, RLC]: + value, value_prev = self.account_write( + account_address, AccountFieldTag.Balance, reversion_info ) + balance, balance_prev = cast_expr(value, RLC), cast_expr(value_prev, RLC) result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result) self.constrain_zero(carry) return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: - balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([balance, *values]) - self.constrain_equal(balance_prev, result) - self.constrain_zero(carry) - return balance, balance_prev - - def sub_balance_with_reversion( + def sub_balance( self, - account_address: int, - values: Sequence[int], - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: - balance, balance_prev = self.account_write_with_reversion( - account_address, - AccountFieldTag.Balance, - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + account_address: Expression, + values: Sequence[RLC], + reversion_info: ReversionInfo = None, + ) -> Tuple[RLC, RLC]: + value, value_prev = self.account_write( + account_address, AccountFieldTag.Balance, reversion_info ) + balance, balance_prev = cast_expr(value, RLC), cast_expr(value_prev, RLC) result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result) self.constrain_zero(carry) return balance, balance_prev - def account_storage_read(self, account_address: int, storage_key: int, tx_id: int) -> FQ: + def account_storage_read( + self, account_address: Expression, storage_key: Expression, tx_id: Expression + ) -> RLC: row = self.rw_lookup( RW.Read, RWTableTag.AccountStorage, - [account_address, storage_key, 0, Tables._, Tables._, tx_id], + account_address, + storage_key, + aux0=tx_id, ) - return row[-4] + return cast_expr(row.value, RLC) - def account_storage_write_with_reversion( + def account_storage_write( self, - account_address: int, - storage_key: int, - tx_id: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ, FQ]: - row = self.state_write_with_reversion( + account_address: Expression, + storage_key: Expression, + tx_id: Expression, + reversion_info: ReversionInfo = None, + ) -> Tuple[RLC, RLC, RLC]: + row = self.state_write( RWTableTag.AccountStorage, - [account_address, storage_key, 0, Tables._, Tables._, tx_id], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + account_address, + storage_key, + aux0=tx_id, + reversion_info=reversion_info, ) - return row[-4], row[-3], row[-1] + return cast_expr(row.value, RLC), cast_expr(row.value_prev, RLC), cast_expr(row.aux1, RLC) def add_account_to_access_list( - self, - tx_id: int, - account_address: int, - ) -> FQ: - row = self.rw_lookup( - RW.Write, - RWTableTag.TxAccessListAccount, - [tx_id, account_address, 0, 1], - ) - return row[-4] - row[-3] - - def add_account_to_access_list_with_reversion( - self, - tx_id: int, - account_address: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + self, tx_id: Expression, account_address: Expression, reversion_info: ReversionInfo = None ) -> FQ: - row = self.state_write_with_reversion( + row = self.state_write( RWTableTag.TxAccessListAccount, - [tx_id, account_address, 0, 1], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + tx_id, + account_address, + value=FQ(1), + reversion_info=reversion_info, ) - return row[-4] - row[-3] + return row.value.expr() - row.value_prev.expr() def add_account_storage_to_access_list( self, - tx_id: int, - account_address: int, - storage_key: int, - ) -> Tuple[bool, bool]: - row = self.rw_lookup( - RW.Write, - RWTableTag.TxAccessListAccountStorage, - [tx_id, account_address, storage_key, 1], - ) - return row[-4] == 1, row[-3] == 1 - - def add_account_storage_to_access_list_with_reversion( - self, - tx_id: int, - account_address: int, - storage_key: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[bool, bool]: - row = self.state_write_with_reversion( + tx_id: Expression, + account_address: Expression, + storage_key: Expression, + reversion_info: ReversionInfo = None, + ) -> FQ: + row = self.state_write( RWTableTag.TxAccessListAccountStorage, - [tx_id, account_address, storage_key, 1], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, + tx_id, + account_address, + storage_key, + value=FQ(1), + reversion_info=reversion_info, ) - return row[-4] == 1, row[-3] == 1 + return row.value.expr() - row.value_prev.expr() def transfer_with_gas_fee( self, - sender_address: int, - receiver_address: int, - value: int, - gas_fee: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: - sender_balance_pair = self.sub_balance_with_reversion( - sender_address, - [value, gas_fee], - is_persistent, - rw_counter_end_of_reversion, - ) - receiver_balance_pair = self.add_balance_with_reversion( - receiver_address, - [value], - is_persistent, - rw_counter_end_of_reversion, - ) + sender_address: Expression, + receiver_address: Expression, + value: RLC, + gas_fee: RLC, + reversion_info: ReversionInfo = None, + ) -> Tuple[Tuple[RLC, RLC], Tuple[RLC, RLC]]: + sender_balance_pair = self.sub_balance(sender_address, [value, gas_fee], reversion_info) + receiver_balance_pair = self.add_balance(receiver_address, [value], reversion_info) return sender_balance_pair, receiver_balance_pair def transfer( self, - sender_address: int, - receiver_address: int, - value: int, - is_persistent: bool, - rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: - sender_balance_pair = self.sub_balance_with_reversion( - sender_address, - [value], - is_persistent, - rw_counter_end_of_reversion, - state_write_counter, - ) - receiver_balance_pair = self.add_balance_with_reversion( - receiver_address, - [value], - is_persistent, - rw_counter_end_of_reversion, - None if state_write_counter is None else state_write_counter + 1, - ) + sender_address: Expression, + receiver_address: Expression, + value: RLC, + reversion_info: ReversionInfo = None, + ) -> Tuple[Tuple[RLC, RLC], Tuple[RLC, RLC]]: + sender_balance_pair = self.sub_balance(sender_address, [value], reversion_info) + receiver_balance_pair = self.add_balance(receiver_address, [value], reversion_info) return sender_balance_pair, receiver_balance_pair - def memory_offset_and_length(self, offset: RLC, length: RLC) -> Tuple[FQ, FQ]: - length = self.rlc_to_fq_exact(length, N_BYTES_MEMORY_SIZE) - if self.is_zero(length): - return FQ.zero(), FQ.zero() - offset = self.rlc_to_fq_exact(offset, N_BYTES_MEMORY_ADDRESS) + def memory_offset_and_length(self, offset_word: RLC, length_word: RLC) -> Tuple[FQ, FQ]: + length = self.rlc_to_fq_exact(length_word, N_BYTES_MEMORY_ADDRESS) + if self.is_zero(length) == 1: + return FQ(0), FQ(0) + offset = self.rlc_to_fq_exact(offset_word, N_BYTES_MEMORY_ADDRESS) return offset, length - def memory_gas_cost(self, memory_size: FQ) -> FQ: + def memory_gas_cost(self, memory_size: Expression) -> FQ: quadratic_cost, _ = self.constant_divmod( - memory_size * memory_size, MEMORY_EXPANSION_QUAD_DENOMINATOR, N_BYTES_GAS + memory_size.expr() * memory_size.expr(), + FQ(MEMORY_EXPANSION_QUAD_DENOMINATOR), + N_BYTES_GAS, ) - linear_cost = MEMORY_EXPANSION_LINEAR_COEFF * memory_size + linear_cost = memory_size.expr() * MEMORY_EXPANSION_LINEAR_COEFF return quadratic_cost + linear_cost - def memory_expansion_constant_length(self, offset: FQ, length: FQ) -> Tuple[FQ, FQ]: - memory_size, _ = self.constant_divmod(length + offset + 31, 32, N_BYTES_MEMORY_SIZE) + def memory_expansion_constant_length( + self, offset: Expression, length: Expression + ) -> Tuple[FQ, FQ]: + memory_size, _ = self.constant_divmod( + length.expr() + offset.expr() + 31, FQ(32), N_BYTES_MEMORY_SIZE + ) next_memory_size = self.max(self.curr.memory_size, memory_size, N_BYTES_MEMORY_SIZE) - memory_gas_cost = self.memory_expansion_gas_cost(self.curr.memory_size) + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) memory_gas_cost_next = self.memory_gas_cost(next_memory_size) memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost - return next_memory_size, memory_expansion_gas_cost + return cast_expr(next_memory_size, FQ), cast_expr(memory_expansion_gas_cost, FQ) def memory_expansion_dynamic_length( self, - cd_offset: FQ, - cd_length: FQ, - rd_offset: Optional[FQ] = None, - rd_length: Optional[FQ] = None, + cd_offset: Expression, + cd_length: Expression, + rd_offset: Optional[Expression] = None, + rd_length: Optional[Expression] = None, ) -> Tuple[FQ, FQ]: cd_memory_size, _ = self.constant_divmod( - cd_offset + cd_length + 31, 32, N_BYTES_MEMORY_SIZE + cd_offset.expr() + cd_length.expr() + FQ(31), FQ(32), N_BYTES_MEMORY_SIZE ) next_memory_size = self.max(self.curr.memory_size, cd_memory_size, N_BYTES_MEMORY_SIZE) - if rd_offset is not None: + if rd_offset is not None and rd_length is not None: rd_memory_size, _ = self.constant_divmod( - rd_offset + rd_length + 31, 32, N_BYTES_MEMORY_SIZE + rd_offset.expr() + rd_length.expr() + FQ(31), FQ(32), N_BYTES_MEMORY_SIZE ) next_memory_size = self.max(next_memory_size, rd_memory_size, N_BYTES_MEMORY_SIZE) @@ -747,10 +706,12 @@ def memory_expansion_dynamic_length( memory_gas_cost_next = self.memory_gas_cost(next_memory_size) memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost - return next_memory_size, memory_expansion_gas_cost + return cast_expr(next_memory_size, FQ), cast_expr(memory_expansion_gas_cost, FQ) - def memory_copier_gas_cost(self, length: FQ, memory_expansion_gas_cost: FQ) -> FQ: - word_size, _ = self.constant_divmod(length + 31, 32, N_BYTES_MEMORY_SIZE) + def memory_copier_gas_cost( + self, length: Expression, memory_expansion_gas_cost: Expression + ) -> FQ: + word_size, _ = self.constant_divmod(length + FQ(31), FQ(32), N_BYTES_MEMORY_SIZE) gas_cost = word_size * GAS_COST_COPY + memory_expansion_gas_cost self.range_check(gas_cost, N_BYTES_GAS) return gas_cost diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index 8a9e6f4c7..8d9fa1131 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -15,12 +15,8 @@ def verify_steps( begin_with_first_step: bool = False, end_with_last_step: bool = False, ): - # For the last step, the next step is meaningless - if end_with_last_step: - steps += [None] - - for idx in range(len(steps) - 1): - curr, next = steps[idx], steps[idx + 1] + for idx in range(len(steps) - 1 + end_with_last_step): + curr, next = steps[idx], None if len(steps) == idx + 1 else steps[idx + 1] verify_step( Instruction( diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 6b7641aa7..14548905f 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -1,6 +1,7 @@ from enum import IntEnum -from typing import Final, Dict, List, Tuple +from typing import Final, Dict, Tuple, List +from ..util import FQ from ..util.param import * @@ -148,6 +149,9 @@ class Opcode(IntEnum): REVERT = 0xFD SELFDESTRUCT = 0xFF + def expr(self) -> FQ: + return FQ(self) + def hex(self) -> str: return "{:02x}".format(self) diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index 28bcf2271..27407b7d0 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any from .execution_state import ExecutionState from ..util import FQ, RLC @@ -40,7 +40,7 @@ class StepState: memory_size: FQ state_write_counter: FQ - # Auxilary witness data needed by gadgets + # Auxiliary witness data needed by gadgets aux_data: Any def __init__( @@ -50,7 +50,7 @@ def __init__( call_id: int = 0, is_root: bool = False, is_create: bool = False, - code_source: int = 0, + code_source: RLC = RLC(0), program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index cd7e42867..3a479faaa 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,17 +1,13 @@ from __future__ import annotations -from typing import Sequence, Set, Tuple +from typing import Any, List, Mapping, Optional, Sequence, Set, Type, TypeVar, Union from enum import IntEnum, auto from itertools import chain, product +from dataclasses import dataclass, field, fields -from ..util import FQ, RLC, Array3, Array4, Array10 +from ..util import Expression, FQ from .execution_state import ExecutionState -class Placeholder: - def __eq__(self, _) -> bool: - return True - - class FixedTableTag(IntEnum): """ Tag for FixedTable lookup, where the FixedTable is a prebuilt fixed-column @@ -30,30 +26,39 @@ class FixedTableTag(IntEnum): BitwiseXor = auto() # lhs, rhs, lhs ^ rhs, 0 ResponsibleOpcode = auto() # execution_state, opcode, aux - def table_assignments(self) -> Sequence[Array4]: + def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range16: - return [(self, i, 0, 0) for i in range(16)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(16)] elif self == FixedTableTag.Range32: - return [(self, i, 0, 0) for i in range(32)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(32)] elif self == FixedTableTag.Range64: - return [(self, i, 0, 0) for i in range(64)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(64)] elif self == FixedTableTag.Range256: - return [(self, i, 0, 0) for i in range(256)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(256)] elif self == FixedTableTag.Range512: - return [(self, i, 0, 0) for i in range(512)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(512)] elif self == FixedTableTag.Range1024: - return [(self, i, 0, 0) for i in range(1024)] + return [FixedTableRow(FQ(self), FQ(i), FQ(0), FQ(0)) for i in range(1024)] elif self == FixedTableTag.SignByte: - return [(self, i, (i >> 7) * 0xFF, 0) for i in range(256)] + return [FixedTableRow(FQ(self), FQ(i), FQ((i >> 7) * 0xFF), FQ(0)) for i in range(256)] elif self == FixedTableTag.BitwiseAnd: - return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self), FQ(lhs), FQ(rhs), FQ(lhs & rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseOr: - return [(self, lhs, rhs, lhs | rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self), FQ(lhs), FQ(rhs), FQ(lhs | rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseXor: - return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self), FQ(lhs), FQ(rhs), FQ(lhs ^ rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.ResponsibleOpcode: return [ - (self, execution_state, opcode, aux) + FixedTableRow(FQ(self), FQ(execution_state), FQ(opcode), FQ(aux)) for execution_state in list(ExecutionState) for opcode, aux in map( lambda pair: pair if isinstance(pair, tuple) else (pair, 0), @@ -119,9 +124,9 @@ class TxContextFieldTag(IntEnum): CallData = auto() -class RW: - Read = False - Write = True +class RW(IntEnum): + Read = 0 + Write = 1 class RWTableTag(IntEnum): @@ -212,116 +217,198 @@ class CallContextFieldTag(IntEnum): StateWriteCounter = auto() +class WrongQueryKey(Exception): + def __init__(self, table_name: str, diff: Set[str]) -> None: + self.message = f"Lookup {table_name} with invalid keys {diff}" + + class LookupUnsatFailure(Exception): - def __init__(self, table_name: str, inputs: Tuple[int, ...]) -> None: + def __init__(self, table_name: str, inputs: Any) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is unsatisfied on inputs {inputs}" class LookupAmbiguousFailure(Exception): - def __init__( - self, table_name: str, inputs: Tuple[int, ...], matched_rows: Sequence[Tuple[int, ...]] - ) -> None: + def __init__(self, table_name: str, inputs: Any, matched_rows: Sequence[Any]) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is ambiguous on inputs {inputs}, ${len(matched_rows)} matched rows found: {matched_rows}" +class TableRow: + @classmethod + def validate_query(cls, table_name: str, query: Mapping[str, Any]): + names = set([field.name for field in fields(cls)]) + queried = set(query.keys()) + if not queried.issubset(names): + raise WrongQueryKey(table_name, queried - names) + + def match(self, query: Mapping[str, Expression]) -> bool: + return all([value.expr() == getattr(self, key).expr() for key, value in query.items()]) + + +@dataclass(frozen=True) +class FixedTableRow(TableRow): + tag: Expression + value0: Expression + value1: Expression = field(default=FQ(0)) + value2: Expression = field(default=FQ(0)) + + +@dataclass(frozen=True) +class BlockTableRow(TableRow): + field_tag: Expression + # meaningful only for HistoryHash, will be zero for other tags + block_number_or_zero: Expression + value: Expression + + +@dataclass(frozen=True) +class TxTableRow(TableRow): + tx_id: Expression + field_tag: Expression + # meaningful only for CallData, will be zero for other tags + call_data_index_or_zero: Expression + value: Expression + + +@dataclass(frozen=True) +class BytecodeTableRow(TableRow): + bytecode_hash: Expression + index: Expression + byte: Expression + is_code: Expression + + +@dataclass(frozen=True) +class RWTableRow(TableRow): + rw_counter: Expression + rw: Expression + key0: Expression # RWTableTag + key1: Expression = field(default=FQ(0)) + key2: Expression = field(default=FQ(0)) + key3: Expression = field(default=FQ(0)) + value: Expression = field(default=FQ(0)) + value_prev: Expression = field(default=FQ(0)) + aux0: Expression = field(default=FQ(0)) + aux1: Expression = field(default=FQ(0)) + + class Tables: """ A collection of lookup tables used in EVM circuit. """ - _: Placeholder = Placeholder() - - # Each row in FixedTable contains: - # - tag - # - value1 - # - value2 - # - value3 - fixed_table: Set[Array4] = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) - - # Each row in BlockTable contains: - # - tag - # - block_number_or_zero (meaningful only for HistoryHash, will be zero for other tags) - # - value - block_table: Set[Array3] - - # Each row in TxTable contains: - # - tx_id - # - tag - # - call_data_index_or_zero (meaningful only for CallData, will be zero for other tags) - # - value - tx_table: Set[Array4] - - # Each row in BytecodeTable contains: - # - bytecode_hash - # - index - # - byte - # - is_code - bytecode_table: Set[Array4] - - # Each row in RWTable contains: - # - rw_counter - # - is_write - # - key0 (tag) - # - key1 - # - key2 - # - key3 - # - value - # - value_prev - # - aux0 - # - aux1 - rw_table: Set[Array10] + fixed_table = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) + block_table: Set[BlockTableRow] + tx_table: Set[TxTableRow] + bytecode_table: Set[BytecodeTableRow] + rw_table: Set[RWTableRow] def __init__( self, - block_table: Set[Array3], - tx_table: Set[Array4], - bytecode_table: Set[Array4], - rw_table: Set[Array10], + block_table: Set[BlockTableRow], + tx_table: Set[TxTableRow], + bytecode_table: Set[BytecodeTableRow], + rw_table: Union[Set[Sequence[Expression]], Set[RWTableRow]], ) -> None: self.block_table = block_table self.tx_table = tx_table self.bytecode_table = bytecode_table - self.rw_table = rw_table - - def fixed_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("fixed_table", self.fixed_table, inputs) - - def block_lookup(self, inputs: Sequence[int]) -> Array3: - assert len(inputs) <= 3 - return _lookup("block_table", self.block_table, inputs) - - def tx_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("tx_table", self.tx_table, inputs) + self.rw_table = set( + row if isinstance(row, RWTableRow) else RWTableRow(*row) # type: ignore # (RWTableRow input args) + for row in rw_table + ) - def bytecode_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("bytecode_table", self.bytecode_table, inputs) - - def rw_lookup(self, inputs: Sequence[int]) -> Array10: - assert len(inputs) <= 10 - return _lookup("rw_table", self.rw_table, inputs) + def fixed_lookup( + self, + tag: Expression, + value0: Expression, + value1: Expression = None, + value2: Expression = None, + ) -> FixedTableRow: + query = { + "tag": tag, + "value0": value0, + "value1": value1, + "value2": value2, + } + return _lookup(FixedTableRow, self.fixed_table, query) + + def block_lookup( + self, field_tag: Expression, block_number: Expression = FQ(0) + ) -> BlockTableRow: + query = {"field_tag": field_tag, "block_number_or_zero": block_number} + return _lookup(BlockTableRow, self.block_table, query) + + def tx_lookup( + self, tx_id: Expression, field_tag: Expression, call_data_index: Expression = FQ(0) + ) -> TxTableRow: + query = { + "tx_id": tx_id, + "field_tag": field_tag, + "call_data_index_or_zero": call_data_index, + } + return _lookup(TxTableRow, self.tx_table, query) + + def bytecode_lookup( + self, bytecode_hash: Expression, index: Expression, is_code: Expression + ) -> BytecodeTableRow: + query = { + "bytecode_hash": bytecode_hash, + "index": index, + "is_code": is_code, + } + return _lookup(BytecodeTableRow, self.bytecode_table, query) + + def rw_lookup( + self, + rw_counter: Expression, + rw: Expression, + tag: Expression, + key1: Expression = None, + key2: Expression = None, + key3: Expression = None, + value: Expression = None, + value_prev: Expression = None, + aux0: Expression = None, + aux1: Expression = None, + ) -> RWTableRow: + query = { + "rw_counter": rw_counter, + "rw": rw, + "key0": tag, + "key1": key1, + "key2": key2, + "key3": key3, + "value": value, + "value_prev": value_prev, + "aux0": aux0, + "aux1": aux1, + } + return _lookup(RWTableRow, self.rw_table, query) + + +T = TypeVar("T", bound=TableRow) def _lookup( - table_name: str, - table: Set[Tuple[int, ...]], - inputs: Sequence[int], -) -> Tuple[int, ...]: - inputs = tuple(inputs) - inputs_len = len(inputs) - matched_rows = [] - - for row in table: - if inputs == row[:inputs_len]: - matched_rows.append(row) + table_cls: Type[T], + table: Set[T], + query: Mapping[str, Optional[Expression]], +) -> T: + table_name = table_cls.__name__ + table_cls.validate_query(table_name, query) + + matched_rows = [ + row + for row in table + # Filter out None values + if row.match({key: value for key, value in query.items() if value is not None}) + ] if len(matched_rows) == 0: - raise LookupUnsatFailure(table_name, inputs) + raise LookupUnsatFailure(table_name, query) elif len(matched_rows) > 1: - raise LookupAmbiguousFailure(table_name, inputs, matched_rows) + raise LookupAmbiguousFailure(table_name, query, matched_rows) - return [v if isinstance(v, RLC) else FQ(v) for v in matched_rows[0]] + return matched_rows[0] diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 303d163e5..9d402397f 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, Iterator, NewType, Optional, Sequence +from typing import Dict, Iterator, List, NewType, Optional, Sequence, Union from functools import reduce from itertools import chain @@ -7,14 +7,26 @@ U64, U160, U256, - Array3, - Array4, + FQ, + IntOrFQ, RLC, + Expression, keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, ) -from .table import BlockContextFieldTag, TxContextFieldTag +from .table import ( + RW, + AccountFieldTag, + BlockContextFieldTag, + BlockTableRow, + BytecodeTableRow, + CallContextFieldTag, + RWTableRow, + RWTableTag, + TxContextFieldTag, + TxTableRow, +) from .opcode import get_push_size, Opcode @@ -38,12 +50,12 @@ class Block: def __init__( self, - coinbase: U160 = 0x10, - gas_limit: U64 = int(15e6), - number: U256 = 0, - timestamp: U64 = 0, - difficulty: U256 = 0, - base_fee: U256 = int(1e9), + coinbase: U160 = U160(0x10), + gas_limit: U64 = U64(int(15e6)), + number: U256 = U256(0), + timestamp: U64 = U64(0), + difficulty: U256 = U256(0), + base_fee: U256 = U256(int(1e9)), history_hashes: Sequence[U256] = [], ) -> None: assert len(history_hashes) <= min(256, number) @@ -56,16 +68,22 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, randomness: int) -> Sequence[Array3]: + def table_assignments(self, randomness: FQ) -> List[BlockTableRow]: return [ - (BlockContextFieldTag.Coinbase, 0, self.coinbase), - (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.Number, 0, RLC(self.number, randomness)), - (BlockContextFieldTag.Timestamp, 0, self.timestamp), - (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), - (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), + BlockTableRow(FQ(BlockContextFieldTag.Coinbase), FQ(0), FQ(self.coinbase)), + BlockTableRow(FQ(BlockContextFieldTag.GasLimit), FQ(0), FQ(self.gas_limit)), + BlockTableRow(FQ(BlockContextFieldTag.Number), FQ(0), RLC(self.number, randomness)), + BlockTableRow(FQ(BlockContextFieldTag.Timestamp), FQ(0), FQ(self.timestamp)), + BlockTableRow( + FQ(BlockContextFieldTag.Difficulty), FQ(0), RLC(self.difficulty, randomness) + ), + BlockTableRow(FQ(BlockContextFieldTag.BaseFee), FQ(0), RLC(self.base_fee, randomness)), ] + [ - (BlockContextFieldTag.HistoryHash, self.number - idx - 1, RLC(history_hash, randomness)) + BlockTableRow( + FQ(BlockContextFieldTag.HistoryHash), + FQ(self.number - idx - 1), + RLC(history_hash, randomness), + ) for idx, history_hash in enumerate(reversed(self.history_hashes)) ] @@ -83,12 +101,12 @@ class Transaction: def __init__( self, id: int = 1, - nonce: U64 = 0, - gas: U64 = 21000, - gas_price: U256 = int(2e9), - caller_address: U160 = 0, - callee_address: Optional[U160] = None, - value: U256 = 0, + nonce: U64 = U64(0), + gas: U64 = U64(21000), + gas_price: U256 = U256(int(2e9)), + caller_address: U160 = U160(0), + callee_address: U160 = None, + value: U256 = U256(0), call_data: bytes = bytes(), ) -> None: self.id = id @@ -114,21 +132,52 @@ def call_data_gas_cost(self) -> int: 0, ) - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: FQ) -> Iterator[TxTableRow]: return chain( [ - (self.id, TxContextFieldTag.Nonce, 0, self.nonce), - (self.id, TxContextFieldTag.Gas, 0, self.gas), - (self.id, TxContextFieldTag.GasPrice, 0, RLC(self.gas_price, randomness)), - (self.id, TxContextFieldTag.CallerAddress, 0, self.caller_address), - (self.id, TxContextFieldTag.CalleeAddress, 0, self.callee_address), - (self.id, TxContextFieldTag.IsCreate, 0, self.callee_address is None), - (self.id, TxContextFieldTag.Value, 0, RLC(self.value, randomness)), - (self.id, TxContextFieldTag.CallDataLength, 0, len(self.call_data)), - (self.id, TxContextFieldTag.CallDataGasCost, 0, self.call_data_gas_cost()), + TxTableRow(FQ(self.id), FQ(TxContextFieldTag.Nonce), FQ(0), FQ(self.nonce)), + TxTableRow(FQ(self.id), FQ(TxContextFieldTag.Gas), FQ(0), FQ(self.gas)), + TxTableRow( + FQ(self.id), + FQ(TxContextFieldTag.GasPrice), + FQ(0), + RLC(self.gas_price, randomness), + ), + TxTableRow( + FQ(self.id), FQ(TxContextFieldTag.CallerAddress), FQ(0), FQ(self.caller_address) + ), + TxTableRow( + FQ(self.id), + FQ(TxContextFieldTag.CalleeAddress), + FQ(0), + FQ(0 if self.callee_address is None else self.callee_address), + ), + TxTableRow( + FQ(self.id), + FQ(TxContextFieldTag.IsCreate), + FQ(0), + FQ(self.callee_address is None), + ), + TxTableRow( + FQ(self.id), FQ(TxContextFieldTag.Value), FQ(0), RLC(self.value, randomness) + ), + TxTableRow( + FQ(self.id), + FQ(TxContextFieldTag.CallDataLength), + FQ(0), + FQ(len(self.call_data)), + ), + TxTableRow( + FQ(self.id), + FQ(TxContextFieldTag.CallDataGasCost), + FQ(0), + FQ(self.call_data_gas_cost()), + ), ], map( - lambda item: (self.id, TxContextFieldTag.CallData, item[0], item[1]), + lambda item: TxTableRow( + FQ(self.id), FQ(TxContextFieldTag.CallData), FQ(item[0]), FQ(item[1]) + ), enumerate(self.call_data), ), ) @@ -163,13 +212,13 @@ def method(*args) -> Bytecode: return method - def push(self, value: Any, n_bytes: int = 32) -> Bytecode: + def push(self, value: Union[int, str, bytes, bytearray, RLC], n_bytes: int = 32) -> Bytecode: if isinstance(value, int): value = value.to_bytes(n_bytes, "big") elif isinstance(value, str): value = bytes.fromhex(value.lower().removeprefix("0x")) elif isinstance(value, RLC): - value = value.be_bytes() + value = bytes(reversed(value.le_bytes)) elif isinstance(value, bytes) or isinstance(value, bytearray): ... else: @@ -179,21 +228,21 @@ def push(self, value: Any, n_bytes: int = 32) -> Bytecode: opcode = Opcode.PUSH1 + n_bytes - 1 self.code.append(opcode) - self.code.extend(value.rjust(n_bytes, bytes(1))) + self.code.extend(value.rjust(n_bytes, b"\x00")) return self - def hash(self) -> int: - return int.from_bytes(keccak256(self.code), "big") + def hash(self) -> U256: + return U256(int.from_bytes(keccak256(self.code), "big")) - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: FQ) -> Iterator[BytecodeTableRow]: class BytecodeIterator: idx: int push_data_left: int - hash: RLC + hash: FQ code: bytes - def __init__(self, hash: RLC, code: bytes): + def __init__(self, hash: FQ, code: bytes): self.idx = 0 self.push_data_left = 0 self.hash = hash @@ -214,9 +263,9 @@ def __next__(self): self.idx += 1 - return (self.hash, idx, byte, is_code) + return BytecodeTableRow(self.hash, FQ(idx), FQ(byte), FQ(is_code)) - return BytecodeIterator(RLC(self.hash(), randomness), self.code) + return BytecodeIterator(RLC(self.hash(), randomness).expr(), self.code) Storage = NewType("Storage", Dict[U256, U256]) @@ -231,20 +280,269 @@ class Account: def __init__( self, - address: U160 = 0, - nonce: U256 = 0, - balance: U256 = 0, - code: Optional[Bytecode] = None, - storage: Optional[Storage] = None, + address: U160 = U160(0), + nonce: U256 = U256(0), + balance: U256 = U256(0), + code: Bytecode = None, + storage: Storage = None, ) -> None: self.address = address self.nonce = nonce self.balance = balance self.code = Bytecode() if code is None else code - self.storage = dict() if storage is None else storage + self.storage = Storage(dict()) if storage is None else storage def code_hash(self) -> U256: return self.code.hash() def storage_trie_hash(self) -> U256: raise NotImplementedError("Trie has not been implemented") + + +class RWDictionary: + rw_counter: int + rws: List[RWTableRow] + + def __init__(self, rw_counter: int) -> None: + self.rw_counter = rw_counter + self.rws = list() + + def stack_read(self, call_id: IntOrFQ, stack_pointer: IntOrFQ, value: RLC) -> RWDictionary: + return self._append( + RW.Read, RWTableTag.Stack, key1=FQ(call_id), key2=FQ(stack_pointer), value=value + ) + + def stack_write(self, call_id: IntOrFQ, stack_pointer: IntOrFQ, value: RLC) -> RWDictionary: + return self._append( + RW.Write, RWTableTag.Stack, key1=FQ(call_id), key2=FQ(stack_pointer), value=value + ) + + def memory_read(self, call_id: IntOrFQ, memory_address: IntOrFQ, byte: IntOrFQ) -> RWDictionary: + return self._append( + RW.Read, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=FQ(byte) + ) + + def memory_write( + self, call_id: IntOrFQ, memory_address: IntOrFQ, byte: IntOrFQ + ) -> RWDictionary: + return self._append( + RW.Write, RWTableTag.Memory, key1=FQ(call_id), key2=FQ(memory_address), value=FQ(byte) + ) + + def call_context_read( + self, call_id: IntOrFQ, field_tag: CallContextFieldTag, value: Union[int, FQ, RLC] + ) -> RWDictionary: + if isinstance(value, int): + value = FQ(value) + return self._append( + RW.Read, RWTableTag.CallContext, key1=FQ(call_id), key2=FQ(field_tag), value=value + ) + + def tx_refund_read(self, tx_id: IntOrFQ, refund: IntOrFQ) -> RWDictionary: + return self._append( + RW.Read, RWTableTag.TxRefund, key1=FQ(tx_id), value=FQ(refund), value_prev=FQ(refund) + ) + + def tx_refund_write( + self, + tx_id: IntOrFQ, + refund: IntOrFQ, + refund_prev: IntOrFQ, + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + return self._state_write( + RWTableTag.TxRefund, + key1=FQ(tx_id), + value=FQ(refund), + value_prev=FQ(refund_prev), + rw_counter_of_reversion=rw_counter_of_reversion, + ) + + def tx_access_list_account_write( + self, + tx_id: IntOrFQ, + account_address: IntOrFQ, + value: bool, + value_prev: bool, + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + return self._state_write( + RWTableTag.TxAccessListAccount, + key1=FQ(tx_id), + key2=FQ(account_address), + value=FQ(value), + value_prev=FQ(value_prev), + rw_counter_of_reversion=rw_counter_of_reversion, + ) + + def tx_access_list_account_storage_write( + self, + tx_id: IntOrFQ, + account_address: IntOrFQ, + storage_key: RLC, + value: bool, + value_prev: bool, + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + return self._state_write( + RWTableTag.TxAccessListAccountStorage, + key1=FQ(tx_id), + key2=FQ(account_address), + key3=storage_key, + value=FQ(value), + value_prev=FQ(value_prev), + rw_counter_of_reversion=rw_counter_of_reversion, + ) + + def account_read( + self, account_address: IntOrFQ, field_tag: AccountFieldTag, value: Union[int, FQ, RLC] + ) -> RWDictionary: + if isinstance(value, int): + value = FQ(value) + return self._append( + RW.Read, + RWTableTag.Account, + key1=FQ(account_address), + key2=FQ(field_tag), + value=value, + value_prev=value, + ) + + def account_write( + self, + account_address: IntOrFQ, + field_tag: AccountFieldTag, + value: Union[int, FQ, RLC], + value_prev: Union[int, FQ, RLC], + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + if isinstance(value, int): + value = FQ(value) + if isinstance(value_prev, int): + value_prev = FQ(value_prev) + return self._state_write( + RWTableTag.Account, + key1=FQ(account_address), + key2=FQ(field_tag), + value=value, + value_prev=value_prev, + rw_counter_of_reversion=rw_counter_of_reversion, + ) + + def account_storage_read( + self, + account_address: IntOrFQ, + storage_key: RLC, + value: RLC, + tx_id: IntOrFQ, + value_committed: RLC, + ) -> RWDictionary: + if isinstance(tx_id, int): + tx_id = FQ(tx_id) + return self._append( + RW.Read, + RWTableTag.AccountStorage, + key1=FQ(account_address), + key2=storage_key, + value=value, + value_prev=value, + aux0=tx_id, + aux1=value_committed, + ) + + def account_storage_write( + self, + account_address: IntOrFQ, + storage_key: RLC, + value: RLC, + value_prev: RLC, + tx_id: IntOrFQ, + value_committed: RLC, + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + if isinstance(tx_id, int): + tx_id = FQ(tx_id) + return self._state_write( + RWTableTag.AccountStorage, + key1=FQ(account_address), + key2=storage_key, + value=value, + value_prev=value_prev, + aux0=tx_id, + aux1=value_committed, + rw_counter_of_reversion=rw_counter_of_reversion, + ) + + def _state_write( + self, + tag: RWTableTag, + key1: Expression = FQ(0), + key2: Expression = FQ(0), + key3: Expression = FQ(0), + value: Expression = FQ(0), + value_prev: Expression = FQ(0), + aux0: Expression = FQ(0), + aux1: Expression = FQ(0), + rw_counter_of_reversion: int = None, + ) -> RWDictionary: + self._append( + RW.Write, + tag=tag, + key1=key1, + key2=key2, + key3=key3, + value=value, + value_prev=value_prev, + aux0=aux0, + aux1=aux1, + ) + + if rw_counter_of_reversion is None: + return self + else: + return self._append( + RW.Write, + tag=tag, + key1=key1, + key2=key2, + key3=key3, + value=value_prev, + value_prev=value, + aux0=aux0, + aux1=aux1, + rw_counter=rw_counter_of_reversion, + ) + + def _append( + self, + rw: RW, + tag: RWTableTag, + key1: Expression = FQ(0), + key2: Expression = FQ(0), + key3: Expression = FQ(0), + value: Expression = FQ(0), + value_prev: Expression = FQ(0), + aux0: Expression = FQ(0), + aux1: Expression = FQ(0), + rw_counter: int = None, + ) -> RWDictionary: + if rw_counter is None: + rw_counter = self.rw_counter + self.rw_counter += 1 + + self.rws.append( + RWTableRow( + FQ(rw_counter), + FQ(rw), + FQ(tag), + key1, + key2, + key3, + value, + value_prev, + aux0, + aux1, + ) + ) + + return self diff --git a/src/zkevm_specs/evm/util/memory_gadget.py b/src/zkevm_specs/evm/util/memory_gadget.py index 418c072cd..a8fa83c23 100644 --- a/src/zkevm_specs/evm/util/memory_gadget.py +++ b/src/zkevm_specs/evm/util/memory_gadget.py @@ -1,5 +1,5 @@ -from ...util import N_BYTES_MEMORY_ADDRESS, FQ -from ..instruction import Instruction, Transition +from ...util import N_BYTES_MEMORY_ADDRESS, FQ, Expression +from ..instruction import Instruction class BufferReaderGadget: @@ -24,7 +24,7 @@ def __init__( diff, inst.select(self.bound_dist_is_zero[i - 1], FQ.zero(), FQ.one()) ) - def constrain_byte(self, idx: int, byte: FQ): + def constrain_byte(self, idx: int, byte: Expression): # bytes[idx] == 0 when selectors[idx] == 0 self.instruction.constrain_zero(byte * (1 - self.selectors[idx])) # bytes[idx] == 0 when bound_dist[idx] == 0 @@ -33,8 +33,8 @@ def constrain_byte(self, idx: int, byte: FQ): def num_bytes(self) -> FQ: return FQ(sum(self.selectors)) - def has_data(self, idx: int) -> bool: + def has_data(self, idx: int) -> FQ: return self.selectors[idx] - def read_flag(self, idx: int) -> bool: - return self.selectors[idx] and not self.bound_dist_is_zero[idx] + def read_flag(self, idx: int) -> FQ: + return self.selectors[idx] * (1 - self.bound_dist_is_zero[idx]) diff --git a/src/zkevm_specs/opcode/comparator.py b/src/zkevm_specs/opcode/comparator.py index 5449f1d93..d39166008 100644 --- a/src/zkevm_specs/opcode/comparator.py +++ b/src/zkevm_specs/opcode/comparator.py @@ -34,7 +34,7 @@ def compare( assert len(result) == 16 # Before we do any comparison, the previous result is "equal" - result = result[:] + [0] + result = list(result[:]) + [Sign(0)] for i in reversed(range(0, 32, 2)): a16 = a8s[i] + 256 * a8s[i + 1] diff --git a/src/zkevm_specs/opcode/lt_gt.py b/src/zkevm_specs/opcode/lt_gt.py index 845c38832..9091ceda6 100644 --- a/src/zkevm_specs/opcode/lt_gt.py +++ b/src/zkevm_specs/opcode/lt_gt.py @@ -1,6 +1,6 @@ from typing import Sequence -from ..encoding import is_circuit_code, U8, U256, u256_to_u8s +from ..encoding import is_circuit_code, U8 def lt_circuit( @@ -34,7 +34,7 @@ def lt_circuit( # lower 16 bytes # a[15:0] + c[15:0] == carry * 256^16 + b[15:0] lhs = 0 - rhs = carry + rhs = int(carry) for i in reversed(range(16)): lhs = lhs * 256 + a8s[i] + c8s[i] rhs = rhs * 256 + b8s[i] diff --git a/src/zkevm_specs/opcode/mload_mstore.py b/src/zkevm_specs/opcode/mload_mstore.py index fc203a2f5..7e0f7e606 100644 --- a/src/zkevm_specs/opcode/mload_mstore.py +++ b/src/zkevm_specs/opcode/mload_mstore.py @@ -15,14 +15,15 @@ def address_low( address: Sequence[U8], ) -> U64: - return sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) + _sum = sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) + return U64(_sum) @is_circuit_code def address_high( address: Sequence[U8], ) -> U256: - return sum(address[NUM_ADDRESS_BYTES_USED:]) + return U256(sum(address[NUM_ADDRESS_BYTES_USED:])) @is_circuit_code @@ -38,7 +39,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code @@ -46,8 +47,8 @@ def div( value: U256, divisor: U64, ) -> Tuple[U256, U256]: - quotient = value // divisor - remainder = value % divisor + quotient = U256(value // divisor) + remainder = U256(value % divisor) return (quotient, remainder) @@ -56,7 +57,7 @@ def lt( lhs: U256, rhs: U256, ) -> U256: - return lhs < rhs + return U256(lhs < rhs) @is_circuit_code @@ -71,7 +72,7 @@ def max( def memory_size( address: U64, ) -> U64: - return (address + 31) // 32 + return U64((address + 31) // 32) @is_circuit_code diff --git a/src/zkevm_specs/state.py b/src/zkevm_specs/state.py index 7a3deb42b..e36d798d7 100644 --- a/src/zkevm_specs/state.py +++ b/src/zkevm_specs/state.py @@ -1,10 +1,10 @@ from typing import NamedTuple, Tuple, List, Sequence -from enum import IntEnum, auto +from enum import IntEnum from math import log, ceil + from .util import FQ, RLC, U160, U256 from .encoding import U8, is_circuit_code -from .evm import RW, RWTableTag -from .evm import AccountFieldTag, CallContextFieldTag +from .evm import RW, AccountFieldTag, CallContextFieldTag MAX_KEY_DIFF = 2**32 - 1 MAX_MEMORY_ADDRESS = 2**32 - 1 @@ -523,7 +523,7 @@ def op2row(op: Operation, randomness: FQ) -> Row: key2_bytes = op.key2.to_bytes(20, "little") key2_limbs = tuple([FQ(key2_bytes[i] + 2**8 * key2_bytes[i + 1]) for i in range(0, 20, 2)]) key3 = FQ(op.key3) - key4_rlc = RLC(op.key4, randomness.n) + key4_rlc = RLC(op.key4, randomness) key4 = key4_rlc.value key4_bytes = tuple([FQ(x) for x in key4_rlc.le_bytes]) value = FQ(op.value) @@ -532,7 +532,8 @@ def op2row(op: Operation, randomness: FQ) -> Row: # fmt: off return Row(rw_counter, is_write, - (key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # keys + # keys + (key0, key1, key2, key3, key4), key2_limbs, key4_bytes, # type: ignore value, (aux0, aux1)) # values # fmt: on diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 2bf9d350f..8817bc997 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -12,16 +12,16 @@ def rand_range(stop: Union[int, float] = 2**256) -> int: return randrange(0, int(stop)) -def rand_fp() -> FQ: +def rand_fq() -> FQ: return FQ(rand_range(FQ.field_modulus)) def rand_address() -> U160: - return rand_range(2**160) + return U160(rand_range(2**160)) def rand_word() -> U256: - return rand_range(2**256) + return U256(rand_range(2**256)) def rand_bytes(n_bytes: int = 32) -> bytes: diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index e91d14635..3bd5325c2 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -1,61 +1,56 @@ from __future__ import annotations -from typing import Sequence, Union -from py_ecc.fields import bn128_FQ as FQ +from typing import Protocol, Sequence, Type, TypeVar, Union +from functools import reduce +from py_ecc import bn128 -def _hash_fq(v: FQ) -> int: - return hash(v.n) +class FQ(bn128.FQ): + def __init__(self, value: IntOrFQ) -> None: + if isinstance(value, FQ): + self.n = value.n + else: + super().__init__(value) + def __hash__(self) -> int: + return hash(self.n) -FQ.__hash__ = _hash_fq -IntOrFQ = Union[int, FQ] + def expr(self) -> FQ: + return FQ(self) + + @staticmethod + def linear_combine(le_bytes: Sequence[int], base: FQ) -> FQ: + def accumulate(acc: FQ, byte: int) -> FQ: + assert ( + 0 <= byte < 256 + ), "Each byte in le_bytes for linear combination should fit in 8-bit" + return acc * base + FQ(byte) + return reduce(accumulate, reversed(le_bytes), FQ(0)) -def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: int) -> FQ: - ret = FQ.zero() - factor = FQ(factor) - for byte in reversed(le_bytes): - assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" - ret = ret * factor + byte - return ret + +IntOrFQ = Union[int, FQ] class RLC: - le_bytes: bytes value: FQ + le_bytes: bytes def __init__( - self, int_or_bytes: Union[IntOrFQ, bytes], randomness: int, n_bytes: int = 32 + self, value: Union[int, bytes], randomness: FQ = FQ(0), n_bytes: int = None ) -> None: - if isinstance(int_or_bytes, int): - assert ( - 0 <= int_or_bytes < 256**n_bytes - ), f"Value {int_or_bytes} too large to fit {n_bytes} bytes" - self.le_bytes = int_or_bytes.to_bytes(n_bytes, "little") - elif isinstance(int_or_bytes, FQ): - assert ( - int_or_bytes.n < 256**n_bytes - ), f"Value {int_or_bytes} too large to fit {n_bytes} bytes" - self.le_bytes = int_or_bytes.n.to_bytes(n_bytes, "little") - elif isinstance(int_or_bytes, bytes): - assert ( - len(int_or_bytes) <= n_bytes - ), f"Expected bytes with length less or equal than {n_bytes}" - self.le_bytes = int_or_bytes.ljust(n_bytes, b"\x00") - else: - raise TypeError( - f"Expected an int or bytes, but got object of type {type(int_or_bytes)}" - ) + if isinstance(value, int): + value = value.to_bytes(32, "little") - self.value = fp_linear_combine(self.le_bytes, randomness) + if n_bytes is not None: + if len(value) > n_bytes: + raise ValueError(f"RLC expects to have {n_bytes} bytes, but got {len(value)} bytes") + value = value.ljust(n_bytes, b"\x00") - def __eq__(self, rhs: Union[int, FQ, RLC]): - if isinstance(rhs, (int, FQ)): - return self.value == rhs - if isinstance(rhs, RLC): - return self.value == rhs.value - else: - raise TypeError(f"Expected a RLC, but got object of type {type(rhs)}") + self.value = FQ.linear_combine(value, randomness) + self.le_bytes = value + + def expr(self) -> FQ: + return FQ(self.value) def __hash__(self) -> int: return hash(self.value) @@ -63,5 +58,16 @@ def __hash__(self) -> int: def __repr__(self) -> str: return "RLC(%s)" % int.from_bytes(self.le_bytes, "little") - def be_bytes(self) -> bytes: - return bytes(reversed(self.le_bytes)) + +class Expression(Protocol): + def expr(self) -> FQ: + ... + + +ExpressionImpl = TypeVar("ExpressionImpl", bound=Expression) + + +def cast_expr(expression: Expression, ty: Type[ExpressionImpl]) -> ExpressionImpl: + if not isinstance(expression, ty): + raise TypeError(f"Casting Expression to {ty}, but got {type(expression)}") + return expression diff --git a/src/zkevm_specs/util/hash.py b/src/zkevm_specs/util/hash.py index 99d750cff..5c71ea5e5 100644 --- a/src/zkevm_specs/util/hash.py +++ b/src/zkevm_specs/util/hash.py @@ -5,11 +5,11 @@ def keccak256(data: Union[str, bytes, bytearray]) -> bytes: - if type(data) == str: + if isinstance(data, str): data = bytes.fromhex(data) return keccak.new(digest_bits=256).update(data).digest() -EMPTY_HASH: U256 = int.from_bytes(keccak256(""), "big") +EMPTY_HASH: U256 = U256(int.from_bytes(keccak256(""), "big")) EMPTY_CODE_HASH: U256 = EMPTY_HASH -EMPTY_TRIE_HASH: U256 = int.from_bytes(keccak256("80"), "big") +EMPTY_TRIE_HASH: U256 = U256(int.from_bytes(keccak256("80"), "big")) diff --git a/src/zkevm_specs/util/typing.py b/src/zkevm_specs/util/typing.py index 0ab5b8234..e8dba8780 100644 --- a/src/zkevm_specs/util/typing.py +++ b/src/zkevm_specs/util/typing.py @@ -1,48 +1,5 @@ -from typing import NewType, Tuple +from typing import NewType U64 = NewType("U64", int) U160 = NewType("U160", int) U256 = NewType("U256", int) - - -Array3 = NewType("Array3", Tuple[int, int, int]) -Array4 = NewType("Array4", Tuple[int, int, int, int]) -Array8 = NewType("Array8", Tuple[int, int, int, int, int, int, int, int]) -Array10 = NewType("Array10", Tuple[int, int, int, int, int, int, int, int, int, int]) -Array32 = NewType( - "Array32", - Tuple[ - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - ], -) diff --git a/tests/evm/test_add.py b/tests/evm/test_add.py index c3bcf2a38..d0a5004b6 100644 --- a/tests/evm/test_add.py +++ b/tests/evm/test_add.py @@ -7,12 +7,11 @@ Opcode, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_word, RLC +from zkevm_specs.util import rand_fq, rand_word, RLC TESTING_DATA = ( @@ -25,7 +24,7 @@ @pytest.mark.parametrize("opcode, a, b, c", TESTING_DATA) def test_add(opcode: Opcode, a: int, b: int, c: Optional[int]): - randomness = rand_fp() + randomness = rand_fq() c = ( RLC(c, randomness) @@ -43,11 +42,11 @@ def test_add(opcode: Opcode, a: int, b: int, c: Optional[int]): tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - (9, RW.Read, RWTableTag.Stack, 1, 1022, 0, a, 0, 0, 0), - (10, RW.Read, RWTableTag.Stack, 1, 1023, 0, b, 0, 0, 0), - (11, RW.Write, RWTableTag.Stack, 1, 1023, 0, c, 0, 0, 0), - ] + RWDictionary(9) + .stack_read(1, 1022, a) + .stack_read(1, 1023, b) + .stack_write(1, 1023, c) + .rws ), ) diff --git a/tests/evm/test_begin_tx.py b/tests/evm/test_begin_tx.py index 6366bf892..591adf2a3 100644 --- a/tests/evm/test_begin_tx.py +++ b/tests/evm/test_begin_tx.py @@ -5,17 +5,15 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, AccountFieldTag, CallContextFieldTag, Block, Transaction, Account, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_address, rand_range, RLC -from zkevm_specs.util.hash import EMPTY_CODE_HASH +from zkevm_specs.util import rand_fq, rand_address, rand_range, RLC, EMPTY_CODE_HASH RETURN_BYTECODE = Bytecode().return_(0, 0) REVERT_BYTECODE = Bytecode().revert(0, 0) @@ -94,9 +92,9 @@ ) -@pytest.mark.parametrize("tx, callee, result", TESTING_DATA) -def test_begin_tx(tx: Transaction, callee: Account, result: bool): - randomness = rand_fp() +@pytest.mark.parametrize("tx, callee, is_success", TESTING_DATA) +def test_begin_tx(tx: Transaction, callee: Account, is_success: bool): + randomness = rand_fq() rw_counter_end_of_reversion = 23 caller_balance_prev = int(1e20) @@ -112,58 +110,28 @@ def test_begin_tx(tx: Transaction, callee: Account, result: bool): bytecode_table=set(callee.code.table_assignments(randomness)), rw_table=set( # fmt: off - [ - (1, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, 0, tx.id, 0, 0, 0), - (2, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.RwCounterEndOfReversion, 0, 0 if result else rw_counter_end_of_reversion, 0, 0, 0), - (3, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.IsPersistent, 0, result, 0, 0, 0), - (4, RW.Write, RWTableTag.Account, tx.caller_address, AccountFieldTag.Nonce, 0, tx.nonce + 1, tx.nonce, 0, 0), - (5, RW.Write, RWTableTag.TxAccessListAccount, 1, tx.caller_address, 0, 1, 0, 0, 0), - (6, RW.Write, RWTableTag.TxAccessListAccount, 1, tx.callee_address, 0, 1, 0, 0, 0), - (7, RW.Write, RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, 0, RLC(caller_balance, randomness), RLC(caller_balance_prev, randomness), 0, 0), - (8, RW.Write, RWTableTag.Account, tx.callee_address, AccountFieldTag.Balance, 0, RLC(callee_balance, randomness), RLC(callee_balance_prev, randomness), 0, 0), - (9, RW.Read, RWTableTag.Account, tx.callee_address, AccountFieldTag.CodeHash, 0, bytecode_hash, bytecode_hash, 0, 0), - (10, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.Depth, 0, 1, 0, 0, 0), - (11, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, 0, tx.caller_address, 0, 0, 0), - (12, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CalleeAddress, 0, tx.callee_address, 0, 0, 0), - (13, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallDataOffset, 0, 0, 0, 0, 0), - (14, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallDataLength, 0, len(tx.call_data), 0, 0, 0), - (15, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.Value, 0, RLC(tx.value, randomness), 0, 0, 0), - (16, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.IsStatic, 0, 0, 0, 0, 0), - (17, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.LastCalleeId, 0, 0, 0, 0, 0), - (18, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.LastCalleeReturnDataOffset, 0, 0, 0, 0, 0), - (19, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.LastCalleeReturnDataLength, 0, 0, 0, 0, 0), - ] + RWDictionary(1) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .call_context_read(1, CallContextFieldTag.RwCounterEndOfReversion, 0 if is_success else rw_counter_end_of_reversion) + .call_context_read(1, CallContextFieldTag.IsPersistent, is_success) + .account_write(tx.caller_address, AccountFieldTag.Nonce, tx.nonce + 1, tx.nonce) + .tx_access_list_account_write(tx.id, tx.caller_address, True, False) + .tx_access_list_account_write(tx.id, tx.callee_address, True, False) + .account_write(tx.caller_address, AccountFieldTag.Balance, RLC(caller_balance, randomness), RLC(caller_balance_prev, randomness), rw_counter_of_reversion=None if is_success else rw_counter_end_of_reversion) + .account_write(tx.callee_address, AccountFieldTag.Balance, RLC(callee_balance, randomness), RLC(callee_balance_prev, randomness), rw_counter_of_reversion=None if is_success else rw_counter_end_of_reversion - 1) + .account_read(tx.callee_address, AccountFieldTag.CodeHash, bytecode_hash) + .call_context_read(1, CallContextFieldTag.Depth, 1) + .call_context_read(1, CallContextFieldTag.CallerAddress, tx.caller_address) + .call_context_read(1, CallContextFieldTag.CalleeAddress, tx.callee_address) + .call_context_read(1, CallContextFieldTag.CallDataOffset, 0) + .call_context_read(1, CallContextFieldTag.CallDataLength, len(tx.call_data)) + .call_context_read(1, CallContextFieldTag.Value, RLC(tx.value, randomness)) + .call_context_read(1, CallContextFieldTag.IsStatic, 0) + .call_context_read(1, CallContextFieldTag.LastCalleeId, 0) + .call_context_read(1, CallContextFieldTag.LastCalleeReturnDataOffset, 0) + .call_context_read(1, CallContextFieldTag.LastCalleeReturnDataLength, 0) + .rws, # fmt: on - + ( - [] - if result - else [ - ( - rw_counter_end_of_reversion - 1, - RW.Write, - RWTableTag.Account, - tx.callee_address, - AccountFieldTag.Balance, - 0, - RLC(callee_balance_prev, randomness), - RLC(callee_balance, randomness), - 0, - 0, - ), - ( - rw_counter_end_of_reversion, - RW.Write, - RWTableTag.Account, - tx.caller_address, - AccountFieldTag.Balance, - 0, - RLC(caller_balance_prev, randomness), - RLC(caller_balance, randomness), - 0, - 0, - ), - ] - ) ), ) diff --git a/tests/evm/test_calldatacopy.py b/tests/evm/test_calldatacopy.py index 63c9c29fc..471143ebf 100644 --- a/tests/evm/test_calldatacopy.py +++ b/tests/evm/test_calldatacopy.py @@ -1,5 +1,5 @@ import pytest -from typing import Sequence, Tuple, Mapping, Optional +from typing import Sequence, Tuple, Mapping from zkevm_specs.evm import ( Opcode, @@ -15,10 +15,11 @@ Block, Transaction, Bytecode, + RWDictionary, ) from zkevm_specs.evm.execution.memory_copy import MAX_COPY_BYTES from zkevm_specs.util import ( - rand_fp, + rand_fq, rand_bytes, GAS_COST_COPY, MEMORY_EXPANSION_QUAD_DENOMINATOR, @@ -56,7 +57,7 @@ def make_copy_step( src_addr_end: int, bytes_left: int, from_tx: bool, - rw_counter: int, + rw_dictionary: RWDictionary, program_counter: int, stack_pointer: int, memory_size: int, @@ -73,7 +74,7 @@ def make_copy_step( ) step = StepState( execution_state=ExecutionState.CopyToMemory, - rw_counter=rw_counter, + rw_counter=rw_dictionary.rw_counter, call_id=1, is_root=from_tx, program_counter=program_counter, @@ -84,42 +85,14 @@ def make_copy_step( aux_data=aux_data, ) - rws = [] num_bytes = min(MAX_COPY_BYTES, bytes_left) for i in range(num_bytes): byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 if not from_tx and src_addr + i < src_addr_end: - rws.append( - ( - rw_counter, - RW.Read, - RWTableTag.Memory, - CALL_ID, - src_addr + i, - 0, - byte, - 0, - 0, - 0, - ) - ) - rw_counter += 1 - rws.append( - ( - rw_counter, - RW.Write, - RWTableTag.Memory, - CALL_ID, - dst_addr + i, - 0, - byte, - 0, - 0, - 0, - ) - ) - rw_counter += 1 - return step, rws + rw_dictionary.memory_read(CALL_ID, src_addr + i, byte) + rw_dictionary.memory_write(CALL_ID, dst_addr + i, byte) + + return step def make_copy_steps( @@ -129,28 +102,26 @@ def make_copy_steps( dst_addr: int, length: int, from_tx: bool, - rw_counter: int, + rw_dictionary: RWDictionary, program_counter: int, stack_pointer: int, memory_size: int, gas_left: int, code_source: RLC, -) -> Tuple[Sequence[StepState], Sequence[RW]]: +) -> Sequence[StepState]: buffer_addr_end = buffer_addr + len(buffer) buffer_map = dict(zip(range(buffer_addr, buffer_addr_end), buffer)) steps = [] - rws = [] bytes_left = length while bytes_left > 0: - curr_rw_counter = rws[-1][0] + 1 if rws else rw_counter - new_step, new_rws = make_copy_step( + new_step = make_copy_step( buffer_map, src_addr, dst_addr, buffer_addr_end, bytes_left, from_tx, - curr_rw_counter, + rw_dictionary, program_counter, stack_pointer, memory_size, @@ -158,11 +129,10 @@ def make_copy_steps( code_source, ) steps.append(new_step) - rws.extend(new_rws) src_addr += MAX_COPY_BYTES dst_addr += MAX_COPY_BYTES bytes_left -= MAX_COPY_BYTES - return steps, rws + return steps def memory_gas_cost(memory_word_size: int) -> int: @@ -190,7 +160,7 @@ def test_calldatacopy( from_tx: bool, call_data_offset: int, ): - randomness = rand_fp() + randomness = rand_fq() bytecode = Bytecode().calldatacopy(memory_offset, data_offset, length) bytecode_hash = RLC(bytecode.hash(), randomness) @@ -229,50 +199,27 @@ def test_calldatacopy( gas_left=gas, ) ] - rws = [ - (1, RW.Read, RWTableTag.Stack, CALL_ID, 1021, 0, memory_offset_rlc, 0, 0, 0), - (2, RW.Read, RWTableTag.Stack, CALL_ID, 1022, 0, data_offset_rlc, 0, 0, 0), - (3, RW.Read, RWTableTag.Stack, CALL_ID, 1023, 0, length_rlc, 0, 0, 0), - (4, RW.Read, RWTableTag.CallContext, CALL_ID, CallContextFieldTag.TxId, 0, TX_ID, 0, 0, 0), - ] + + rw_dictionary = ( + RWDictionary(1) + .stack_read(CALL_ID, 1021, memory_offset_rlc) + .stack_read(CALL_ID, 1022, data_offset_rlc) + .stack_read(CALL_ID, 1023, length_rlc) + .call_context_read(CALL_ID, CallContextFieldTag.TxId, TX_ID) + ) if not from_tx: - rws.append( - ( - 5, - RW.Read, - RWTableTag.CallContext, - CALL_ID, - CallContextFieldTag.CallDataLength, - 0, - call_data_length, - 0, - 0, - 0, - ) - ) - rws.append( - ( - 6, - RW.Read, - RWTableTag.CallContext, - CALL_ID, - CallContextFieldTag.CallDataOffset, - 0, - call_data_offset, - 0, - 0, - 0, - ) - ) + rw_dictionary.call_context_read( + CALL_ID, CallContextFieldTag.CallDataLength, call_data_length + ).call_context_read(CALL_ID, CallContextFieldTag.CallDataOffset, call_data_offset) - new_steps, new_rws = make_copy_steps( + new_steps = make_copy_steps( call_data, call_data_offset, call_data_offset + data_offset, memory_offset, length, from_tx, - rw_counter=rws[-1][0] + 1, + rw_dictionary=rw_dictionary, program_counter=100, memory_size=next_memory_word_size, stack_pointer=1024, @@ -280,12 +227,11 @@ def test_calldatacopy( code_source=bytecode_hash, ) steps.extend(new_steps) - rws.extend(new_rws) steps.append( StepState( execution_state=ExecutionState.STOP, - rw_counter=rws[-1][0] + 1, + rw_counter=rw_dictionary.rw_counter, call_id=CALL_ID, is_root=from_tx, is_create=False, @@ -301,7 +247,7 @@ def test_calldatacopy( block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set(rws), + rw_table=set(rw_dictionary.rws), ) verify_steps( diff --git a/tests/evm/test_calldataload.py b/tests/evm/test_calldataload.py index a207217f4..fa7e0a923 100644 --- a/tests/evm/test_calldataload.py +++ b/tests/evm/test_calldataload.py @@ -5,14 +5,13 @@ Bytecode, CallContextFieldTag, ExecutionState, - RW, - RWTableTag, StepState, Tables, Transaction, verify_steps, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC, U64 +from zkevm_specs.util import rand_fq, RLC, U64 TESTING_DATA = ( ( @@ -21,7 +20,7 @@ 0x00, bytes.fromhex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"), True, - None, + 0, ), ( bytes.fromhex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"), @@ -29,7 +28,7 @@ 0x1F, bytes.fromhex("FF00000000000000000000000000000000000000000000000000000000000000"), True, - None, + 0, ), ( bytes.fromhex("a1bacf5488bfafc33bad736db41f06866eaeb35e1c1dd81dfc268357ec98563f"), @@ -37,7 +36,7 @@ 0x10, bytes.fromhex("6eaeb35e1c1dd81dfc268357ec98563f00000000000000000000000000000000"), True, - None, + 0, ), ( bytes.fromhex("a1bacf5488bfafc33bad736db41f06866eaeb35e1c1dd81dfc268357ec98563f"), @@ -68,9 +67,9 @@ def test_calldataload( offset: U64, expected_stack_top: bytes, is_root: bool, - call_data_offset: Optional[U64], + call_data_offset: U64, ): - randomness = rand_fp() + randomness = rand_fq() tx = Transaction(id=1) if is_root: @@ -88,104 +87,38 @@ def test_calldataload( call_id = 2 parent_call_id = 1 - rws = set( - [ - (1, RW.Write, RWTableTag.Stack, call_id, 1023, 0, offset_rlc, 0, 0, 0), - (2, RW.Read, RWTableTag.Stack, call_id, 1023, 0, offset_rlc, 0, 0, 0), - (3, RW.Read, RWTableTag.CallContext, call_id, CallContextFieldTag.TxId, 0, 1, 0, 0, 0), - ] + rw_dictionary = ( + RWDictionary(1) + .stack_write(call_id, 1023, offset_rlc) + .stack_read(call_id, 1023, offset_rlc) + .call_context_read(call_id, CallContextFieldTag.TxId, 1) ) if is_root: - rws.add((4, RW.Write, RWTableTag.Stack, call_id, 1023, 0, expected_stack_top, 0, 0, 0)) - rw_counter_stop = 5 + rw_dictionary.stack_write(call_id, 1023, expected_stack_top) else: # add to RW table call context, call data length (read) - rws.add( - ( - 4, - RW.Read, - RWTableTag.CallContext, - call_id, - CallContextFieldTag.CallDataLength, - 0, - call_data_length, - 0, - 0, - 0, - ) + rw_dictionary.call_context_read( + call_id, CallContextFieldTag.CallDataLength, call_data_length ) # add to RW table call context, call data offset (read) - rws.add( - ( - 5, - RW.Read, - RWTableTag.CallContext, - call_id, - CallContextFieldTag.CallDataOffset, - 0, - call_data_offset, - 0, - 0, - 0, - ) + rw_dictionary.call_context_read( + call_id, CallContextFieldTag.CallDataOffset, call_data_offset ) # add to RW table call context, caller'd ID (read) - rws.add( - ( - 6, - RW.Read, - RWTableTag.CallContext, - call_id, - CallContextFieldTag.CallerId, - 0, - parent_call_id, - 0, - 0, - 0, - ) - ) - rw_counter = 7 + rw_dictionary.call_context_read(call_id, CallContextFieldTag.CallerId, parent_call_id) # add to RW table memory (read) for i in range(0, len(call_data)): idx = offset + call_data_offset + i if idx < len(call_data): - rws.add( - ( - rw_counter, - RW.Read, - RWTableTag.Memory, - parent_call_id, - idx, - 0, - call_data[idx], - 0, - 0, - 0, - ) - ) - rw_counter += 1 + rw_dictionary.memory_read(parent_call_id, idx, call_data[idx]) # add to RW table stack (write) - rws.add( - ( - rw_counter, - RW.Write, - RWTableTag.Stack, - call_id, - 1023, - 0, - expected_stack_top, - 0, - 0, - 0, - ) - ) - rw_counter_stop = rw_counter + 1 + rw_dictionary.stack_write(call_id, 1023, expected_stack_top) tables = Tables( block_table=set(), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=rws, + rw_table=rw_dictionary.rws, ) verify_steps( @@ -216,7 +149,7 @@ def test_calldataload( ), StepState( execution_state=ExecutionState.STOP, - rw_counter=rw_counter_stop, + rw_counter=rw_dictionary.rw_counter, call_id=call_id, is_root=is_root, is_create=False, diff --git a/tests/evm/test_calldatasize.py b/tests/evm/test_calldatasize.py index 3f8a1879a..0ad7e08c0 100644 --- a/tests/evm/test_calldatasize.py +++ b/tests/evm/test_calldatasize.py @@ -5,13 +5,11 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, CallContextFieldTag, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC, U64 -from zkevm_specs.util.param import N_BYTES_U64 +from zkevm_specs.util import rand_fq, RLC, U64 TESTING_DATA = ( @@ -23,7 +21,7 @@ @pytest.mark.parametrize("calldatasize", TESTING_DATA) def test_calldatasize(calldatasize: U64): - randomness = rand_fp() + randomness = rand_fq() bytecode = Bytecode().calldatasize() bytecode_hash = RLC(bytecode.hash(), randomness) @@ -33,12 +31,10 @@ def test_calldatasize(calldatasize: U64): tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - # fmt: off - (9, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallDataLength, 0, calldatasize, 0, 0, 0), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(calldatasize, randomness, N_BYTES_U64), 0, 0, 0), - # fmt: on - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.CallDataLength, calldatasize) + .stack_write(1, 1023, RLC(calldatasize, randomness)) + .rws ), ) diff --git a/tests/evm/test_caller.py b/tests/evm/test_caller.py index efe13bd26..19abbd5f3 100644 --- a/tests/evm/test_caller.py +++ b/tests/evm/test_caller.py @@ -5,13 +5,11 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, CallContextFieldTag, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_address, rand_fp, RLC, U160 -from zkevm_specs.util.param import N_BYTES_ACCOUNT_ADDRESS +from zkevm_specs.util import rand_address, rand_fq, RLC, U160 TESTING_DATA = ( @@ -25,7 +23,7 @@ @pytest.mark.parametrize("caller", TESTING_DATA) def test_caller(caller: U160): - randomness = rand_fp() + randomness = rand_fq() bytecode = Bytecode().caller() bytecode_hash = RLC(bytecode.hash(), randomness) @@ -35,12 +33,10 @@ def test_caller(caller: U160): tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - # fmt: off - (9, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CallerAddress, 0, caller, 0, 0, 0), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(caller, randomness, N_BYTES_ACCOUNT_ADDRESS), 0, 0, 0), - # fmt: on - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.CallerAddress, caller) + .stack_write(1, 1023, RLC(caller, randomness)) + .rws ), ) diff --git a/tests/evm/test_callvalue.py b/tests/evm/test_callvalue.py index f5a1f5ed5..d8b7b0000 100644 --- a/tests/evm/test_callvalue.py +++ b/tests/evm/test_callvalue.py @@ -3,16 +3,13 @@ from zkevm_specs.evm import ( ExecutionState, StepState, - Opcode, verify_steps, Tables, - RWTableTag, - RW, CallContextFieldTag, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC, U256 -from zkevm_specs.util.param import N_BYTES_WORD +from zkevm_specs.util import rand_fq, RLC, U256 TESTING_DATA = ( @@ -25,9 +22,7 @@ @pytest.mark.parametrize("callvalue", TESTING_DATA) def test_callvalue(callvalue: U256): - randomness = rand_fp() - - callvalue_rlc = RLC(callvalue, randomness, N_BYTES_WORD) + randomness = rand_fq() bytecode = Bytecode().callvalue() bytecode_hash = RLC(bytecode.hash(), randomness) @@ -37,12 +32,10 @@ def test_callvalue(callvalue: U256): tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - # fmt: off - (9, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.Value, 0, callvalue_rlc, 0, 0, 0), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, callvalue_rlc, 0, 0, 0), - # fmt: on - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.Value, RLC(callvalue, randomness)) + .stack_write(1, 1023, RLC(callvalue, randomness)) + .rws ), ) diff --git a/tests/evm/test_coinbase.py b/tests/evm/test_coinbase.py index 482753baf..6d80f7770 100644 --- a/tests/evm/test_coinbase.py +++ b/tests/evm/test_coinbase.py @@ -5,12 +5,11 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_address, rand_fp, RLC, U160 +from zkevm_specs.util import rand_address, rand_fq, RLC, U160 TESTING_DATA = (0x030201, rand_address()) @@ -18,7 +17,7 @@ @pytest.mark.parametrize("coinbase", TESTING_DATA) def test_coinbase(coinbase: U160): - randomness = rand_fp() + randomness = rand_fq() block = Block(coinbase=coinbase) @@ -29,11 +28,7 @@ def test_coinbase(coinbase: U160): block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(coinbase, randomness, 20), 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(9).stack_write(1, 1023, RLC(coinbase, randomness)).rws), ) verify_steps( diff --git a/tests/evm/test_end_block.py b/tests/evm/test_end_block.py index 68f02c64c..8029e3501 100644 --- a/tests/evm/test_end_block.py +++ b/tests/evm/test_end_block.py @@ -7,19 +7,20 @@ verify_steps, Tables, RWTableTag, + RWTableRow, RW, CallContextFieldTag, Block, Transaction, ) -from zkevm_specs.util import rand_fp +from zkevm_specs.util import rand_fq, FQ TESTING_DATA = (False, True) @pytest.mark.parametrize("is_last_step", TESTING_DATA) def test_end_block(is_last_step: bool): - randomness = rand_fp() + randomness = rand_fq() tx = Transaction() @@ -30,8 +31,8 @@ def test_end_block(is_last_step: bool): rw_table=set( chain( # dummy read/write for counting - [(i, *7 * [0]) for i in range(22)], - [(22, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, 0, tx.id, 0, 0, 0)] # fmt: skip + [RWTableRow(FQ(i), *9 * [FQ(0)]) for i in range(22)], + [RWTableRow(FQ(22), FQ(RW.Read), FQ(RWTableTag.CallContext), FQ(1), FQ(CallContextFieldTag.TxId), value=FQ(tx.id))] # fmt: skip if is_last_step else [], ) ), diff --git a/tests/evm/test_end_tx.py b/tests/evm/test_end_tx.py index 1338df495..33f856ad7 100644 --- a/tests/evm/test_end_tx.py +++ b/tests/evm/test_end_tx.py @@ -5,14 +5,13 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, AccountFieldTag, CallContextFieldTag, Block, Transaction, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC, EMPTY_CODE_HASH, MAX_REFUND_QUOTIENT_OF_GAS_USED +from zkevm_specs.util import rand_fq, RLC, EMPTY_CODE_HASH, MAX_REFUND_QUOTIENT_OF_GAS_USED CALLEE_ADDRESS = 0xFF @@ -49,7 +48,7 @@ @pytest.mark.parametrize("tx, gas_left, refund, is_last_tx", TESTING_DATA) def test_end_tx(tx: Transaction, gas_left: int, refund: int, is_last_tx: bool): - randomness = rand_fp() + randomness = rand_fq() block = Block() effective_refund = min(refund, (tx.gas - gas_left) // MAX_REFUND_QUOTIENT_OF_GAS_USED) @@ -58,25 +57,23 @@ def test_end_tx(tx: Transaction, gas_left: int, refund: int, is_last_tx: bool): coinbase_balance_prev = 0 coinbase_balance = coinbase_balance_prev + (tx.gas - gas_left) * (tx.gas_price - block.base_fee) + rw_dictionary = ( + # fmt: off + RWDictionary(17) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .tx_refund_read(tx.id, refund) + .account_write(tx.caller_address, AccountFieldTag.Balance, RLC(caller_balance, randomness), RLC(caller_balance_prev, randomness)) + .account_write(block.coinbase, AccountFieldTag.Balance, RLC(coinbase_balance, randomness), RLC(coinbase_balance_prev, randomness)) + # fmt: on + ) + if not is_last_tx: + rw_dictionary.call_context_read(22, CallContextFieldTag.TxId, tx.id + 1) + tables = Tables( block_table=set(block.table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(), - rw_table=set( - [ - # fmt: off - (17, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.TxId, 0, tx.id, 0, 0, 0), - (18, RW.Read, RWTableTag.TxRefund, tx.id, 0, 0, refund, refund, 0, 0), - (19, RW.Write, RWTableTag.Account, tx.caller_address, AccountFieldTag.Balance, 0, RLC(caller_balance, randomness), RLC(caller_balance_prev, randomness), 0, 0), - (20, RW.Write, RWTableTag.Account, block.coinbase, AccountFieldTag.Balance, 0, RLC(coinbase_balance, randomness), RLC(coinbase_balance_prev, randomness), 0, 0), - # fmt: on - ] - + ( - [] - if is_last_tx - else [(21, RW.Read, RWTableTag.CallContext, 22, CallContextFieldTag.TxId, 0, tx.id + 1, 0, 0, 0)] # fmt: skip - ) - ), + rw_table=set(rw_dictionary.rws), ) verify_steps( diff --git a/tests/evm/test_gas.py b/tests/evm/test_gas.py index 97a3b560b..ff496c02d 100644 --- a/tests/evm/test_gas.py +++ b/tests/evm/test_gas.py @@ -3,17 +3,14 @@ from zkevm_specs.evm import ( Block, Bytecode, - CallContextFieldTag, ExecutionState, StepState, - Opcode, - RW, - RWTableTag, Tables, Transaction, verify_steps, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_range, RLC +from zkevm_specs.util import rand_fq, rand_range, RLC # Start with different values for `gas` before calling the `GAS` opcode. TESTING_DATA = tuple([i for i in range(2, 10)] + [rand_range(2**64) for i in range(0, 10)]) @@ -21,9 +18,9 @@ @pytest.mark.parametrize("gas", TESTING_DATA) def test_gas(gas: int): - randomness = rand_fp() + randomness = rand_fq() - tx = Transaction(gas=gas) + tx = Transaction() bytecode = Bytecode().gas().stop() bytecode_hash = RLC(bytecode.hash(), randomness) @@ -36,11 +33,7 @@ def test_gas(gas: int): block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (2, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(gas_left, randomness), 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(2).stack_write(1, 1023, RLC(gas_left, randomness)).rws), ) verify_steps( diff --git a/tests/evm/test_gasprice.py b/tests/evm/test_gasprice.py index 8b98d0fd9..4f6d71c68 100644 --- a/tests/evm/test_gasprice.py +++ b/tests/evm/test_gasprice.py @@ -6,15 +6,12 @@ CallContextFieldTag, ExecutionState, StepState, - Opcode, - RW, - RWTableTag, Tables, Transaction, verify_steps, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_range, RLC -from zkevm_specs.util.typing import U256 +from zkevm_specs.util import rand_fq, RLC, U256 TESTING_DATA = ( 0x00, @@ -26,7 +23,7 @@ @pytest.mark.parametrize("gasprice", TESTING_DATA) def test_gasprice(gasprice: U256): - randomness = rand_fp() + randomness = rand_fq() tx = Transaction(gas_price=gasprice) @@ -34,25 +31,14 @@ def test_gasprice(gasprice: U256): bytecode_hash = RLC(bytecode.hash(), randomness) tables = Tables( - block_table=set(), + block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - ( - 9, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.TxId, - 0, - tx.id, - 0, - 0, - 0, - ), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(gasprice, randomness), 0, 0, 0), - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .stack_write(1, 1023, RLC(gasprice, randomness)) + .rws ), ) diff --git a/tests/evm/test_jump.py b/tests/evm/test_jump.py index c30adb868..ea1818be5 100644 --- a/tests/evm/test_jump.py +++ b/tests/evm/test_jump.py @@ -6,12 +6,11 @@ Opcode, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC +from zkevm_specs.util import rand_fq, RLC TESTING_DATA = ((Opcode.JUMP, bytes([7])),) @@ -19,7 +18,7 @@ @pytest.mark.parametrize("opcode, dest_bytes", TESTING_DATA) def test_jump(opcode: Opcode, dest_bytes: bytes): - randomness = rand_fp() + randomness = rand_fq() dest = RLC(bytes(reversed(dest_bytes)), randomness) block = Block() @@ -32,11 +31,7 @@ def test_jump(opcode: Opcode, dest_bytes: bytes): block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Read, RWTableTag.Stack, 1, 1021, 0, dest, 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(9).stack_read(1, 1021, dest).rws), ) verify_steps( diff --git a/tests/evm/test_jumpi.py b/tests/evm/test_jumpi.py index ae49736af..11d70791f 100644 --- a/tests/evm/test_jumpi.py +++ b/tests/evm/test_jumpi.py @@ -6,12 +6,11 @@ Opcode, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, RLC +from zkevm_specs.util import rand_fq, RLC TESTING_DATA = ((Opcode.JUMPI, bytes([40]), bytes([7])),) @@ -19,7 +18,7 @@ @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA) def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - randomness = rand_fp() + randomness = rand_fq() cond = RLC(bytes(reversed(cond_bytes)), randomness) dest = RLC(bytes(reversed(dest_bytes)), randomness) @@ -33,12 +32,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Read, RWTableTag.Stack, 1, 1021, 0, dest, 0, 0, 0), - (10, RW.Read, RWTableTag.Stack, 1, 1022, 0, cond, 0, 0, 0), - ], - ), + rw_table=set(RWDictionary(9).stack_read(1, 1021, dest).stack_read(1, 1022, cond).rws), ) verify_steps( @@ -76,7 +70,7 @@ def test_jumpi_cond_nonzero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes @pytest.mark.parametrize("opcode, cond_bytes, dest_bytes", TESTING_DATA_ZERO_COND) def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): - randomness = rand_fp() + randomness = rand_fq() cond = RLC(bytes(reversed(cond_bytes)), randomness) dest = RLC(bytes(reversed(dest_bytes)), randomness) @@ -90,12 +84,7 @@ def test_jumpi_cond_zero(opcode: Opcode, cond_bytes: bytes, dest_bytes: bytes): block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Read, RWTableTag.Stack, 1, 1021, 0, dest, 0, 0, 0), - (10, RW.Read, RWTableTag.Stack, 1, 1022, 0, cond, 0, 0, 0), - ], - ), + rw_table=set(RWDictionary(9).stack_read(1, 1021, dest).stack_read(1, 1022, cond).rws), ) verify_steps( diff --git a/tests/evm/test_number.py b/tests/evm/test_number.py index e3245c1d8..8c0508755 100644 --- a/tests/evm/test_number.py +++ b/tests/evm/test_number.py @@ -5,12 +5,11 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_address, rand_fp, RLC, U256, rand_word +from zkevm_specs.util import rand_fq, RLC, U256, rand_word TESTING_DATA = (0, 1, 2**256 - 1, rand_word()) @@ -18,7 +17,7 @@ @pytest.mark.parametrize("number", TESTING_DATA) def test_number(number: U256): - randomness = rand_fp() + randomness = rand_fq() block = Block(number=number) @@ -29,11 +28,7 @@ def test_number(number: U256): block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(number, randomness), 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(9).stack_write(1, 1023, RLC(number, randomness)).rws), ) verify_steps( diff --git a/tests/evm/test_origin.py b/tests/evm/test_origin.py index 0bb150736..c9ab32eef 100644 --- a/tests/evm/test_origin.py +++ b/tests/evm/test_origin.py @@ -5,15 +5,12 @@ CallContextFieldTag, ExecutionState, StepState, - RW, - RWTableTag, Tables, Transaction, verify_steps, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_address, RLC -from zkevm_specs.util.typing import U256 -from zkevm_specs.util.param import N_BYTES_ACCOUNT_ADDRESS +from zkevm_specs.util import rand_fq, rand_address, RLC, U256 TESTING_DATA = ( 0x00, @@ -26,7 +23,7 @@ @pytest.mark.parametrize("origin", TESTING_DATA) def test_origin(origin: U256): - randomness = rand_fp() + randomness = rand_fq() tx = Transaction(caller_address=origin) @@ -38,32 +35,10 @@ def test_origin(origin: U256): tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - ( - 9, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.TxId, - 0, - tx.id, - 0, - 0, - 0, - ), - ( - 10, - RW.Write, - RWTableTag.Stack, - 1, - 1023, - 0, - RLC(origin, randomness, N_BYTES_ACCOUNT_ADDRESS), - 0, - 0, - 0, - ), - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .stack_write(1, 1023, RLC(origin, randomness)) + .rws ), ) diff --git a/tests/evm/test_push.py b/tests/evm/test_push.py index 2262db9ed..f73e9c7c9 100644 --- a/tests/evm/test_push.py +++ b/tests/evm/test_push.py @@ -5,12 +5,11 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_bytes, rand_fp, RLC +from zkevm_specs.util import rand_bytes, rand_fq, RLC TESTING_DATA = tuple( @@ -26,9 +25,9 @@ @pytest.mark.parametrize("value_be_bytes", TESTING_DATA) def test_push(value_be_bytes: bytes): - randomness = rand_fp() + randomness = rand_fq() - value = RLC(bytes(reversed(value_be_bytes)), randomness) + value = RLC(bytes(reversed(value_be_bytes)), randomness, 32) bytecode = Bytecode().push(value_be_bytes, n_bytes=len(value_be_bytes)) bytecode_hash = RLC(bytecode.hash(), randomness) @@ -37,11 +36,7 @@ def test_push(value_be_bytes: bytes): block_table=set(Block().table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (8, RW.Write, RWTableTag.Stack, 1, 1023, 0, value, 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(8).stack_write(1, 1023, value).rws), ) verify_steps( diff --git a/tests/evm/test_selfbalance.py b/tests/evm/test_selfbalance.py index 46c7c76c0..bf4124665 100644 --- a/tests/evm/test_selfbalance.py +++ b/tests/evm/test_selfbalance.py @@ -5,39 +5,34 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, CallContextFieldTag, AccountFieldTag, + RWDictionary, ) -from zkevm_specs.util import rand_address, rand_word, rand_fp, RLC, U256, U160 +from zkevm_specs.util import rand_address, rand_word, rand_fq, RLC, U256, U160 TESTING_DATA = [(0, 0), (0, 10), (rand_address(), rand_word())] @pytest.mark.parametrize("callee_address, balance", TESTING_DATA) def test_selfbalance(callee_address: U160, balance: U256): - randomness = rand_fp() + randomness = rand_fq() bytecode = Bytecode().selfbalance() bytecode_hash = RLC(bytecode.hash(), randomness) - rlc_balance = RLC(balance, randomness) - tables = Tables( block_table=Block(), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - # fmt: off - (9, RW.Read, RWTableTag.CallContext, 1, CallContextFieldTag.CalleeAddress, 0, callee_address, 0, 0, 0), - (10, RW.Read, RWTableTag.Account, callee_address, AccountFieldTag.Balance, 0, rlc_balance, rlc_balance, 0, 0, 0), - (11, RW.Write, RWTableTag.Stack, 1, 1023, 0, rlc_balance, 0, 0, 0), - # fmt: on - ] + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.CalleeAddress, callee_address) + .account_read(callee_address, AccountFieldTag.Balance, RLC(balance, randomness)) + .stack_write(1, 1023, RLC(balance, randomness)) + .rws ), ) diff --git a/tests/evm/test_sload.py b/tests/evm/test_sload.py index a575cdfff..6b5316c9d 100644 --- a/tests/evm/test_sload.py +++ b/tests/evm/test_sload.py @@ -5,15 +5,13 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, CallContextFieldTag, Transaction, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util.param import COLD_SLOAD_COST, WARM_STORAGE_READ_COST -from zkevm_specs.util import rand_fp, rand_address, RLC +from zkevm_specs.util import rand_fq, rand_address, RLC, COLD_SLOAD_COST, WARM_STORAGE_READ_COST TESTING_DATA = ( ( @@ -43,118 +41,38 @@ ) -@pytest.mark.parametrize("tx, storage_key_be_bytes, warm, result", TESTING_DATA) -def test_sload(tx: Transaction, storage_key_be_bytes: bytes, warm: bool, result: bool): - randomness = rand_fp() +@pytest.mark.parametrize("tx, storage_key_be_bytes, warm, is_persistent", TESTING_DATA) +def test_sload(tx: Transaction, storage_key_be_bytes: bytes, warm: bool, is_persistent: bool): + randomness = rand_fq() storage_key = RLC(bytes(reversed(storage_key_be_bytes)), randomness) bytecode = Bytecode().push32(storage_key_be_bytes).sload().stop() bytecode_hash = RLC(bytecode.hash(), randomness) - value = 2 - value_prev = 0 - value_committed = 0 + value = RLC(2, randomness) + value_committed = RLC(0, randomness) + + rw_counter_end_of_reversion = 19 + state_write_counter = 3 tables = Tables( block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - ( - 9, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.TxId, - 0, - tx.id, - 0, - 0, - 0, - ), - ( - 10, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.RwCounterEndOfReversion, - 0, - 0 if result else 19, - 0, - 0, - 0, - ), - ( - 11, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.IsPersistent, - 0, - result, - 0, - 0, - 0, - ), - ( - 12, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.CalleeAddress, - 0, - tx.callee_address, - 0, - 0, - 0, - ), - (13, RW.Read, RWTableTag.Stack, 1, 1023, 0, storage_key, 0, 0, 0), - ( - 14, - RW.Read, - RWTableTag.AccountStorage, - tx.callee_address, - storage_key, - 0, - value, - value_prev, - tx.id, - value_committed, - ), - (15, RW.Write, RWTableTag.Stack, 1, 1023, 0, value, 0, 0, 0), - ( - 16, - RW.Write, - RWTableTag.TxAccessListAccountStorage, - tx.id, - tx.callee_address, - storage_key, - 1, - 1 if warm else 0, - 0, - 0, - ), - ] - + ( - [] - if result - else [ - ( - 19, - RW.Write, - RWTableTag.TxAccessListAccountStorage, - tx.id, - tx.callee_address, - storage_key, - 1 if warm else 0, - 1, - 0, - 0, - ), - ] - ) + # fmt: off + RWDictionary(9) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .call_context_read(1, CallContextFieldTag.RwCounterEndOfReversion, 0 if is_persistent else rw_counter_end_of_reversion) + .call_context_read(1, CallContextFieldTag.IsPersistent, is_persistent) + .call_context_read(1, CallContextFieldTag.CalleeAddress, tx.callee_address) + .stack_read(1, 1023, storage_key) + .account_storage_read(tx.callee_address, storage_key, value, tx.id, value_committed) + .stack_write(1, 1023, value) + .tx_access_list_account_storage_write(tx.id, tx.callee_address, storage_key, 1, 1 if warm else 0, rw_counter_of_reversion=None if is_persistent else rw_counter_end_of_reversion - state_write_counter) + .rws + # fmt: on ), ) @@ -171,11 +89,11 @@ def test_sload(tx: Transaction, storage_key_be_bytes: bytes, warm: bool, result: code_source=bytecode_hash, program_counter=33, stack_pointer=1023, - state_write_counter=0, + state_write_counter=state_write_counter, gas_left=WARM_STORAGE_READ_COST if warm else COLD_SLOAD_COST, ), StepState( - execution_state=ExecutionState.STOP if result else ExecutionState.REVERT, + execution_state=ExecutionState.STOP if is_persistent else ExecutionState.REVERT, rw_counter=17, call_id=1, is_root=True, @@ -183,7 +101,7 @@ def test_sload(tx: Transaction, storage_key_be_bytes: bytes, warm: bool, result: code_source=bytecode_hash, program_counter=34, stack_pointer=1023, - state_write_counter=1, + state_write_counter=state_write_counter + 1, gas_left=0, ), ], diff --git a/tests/evm/test_slt_sgt.py b/tests/evm/test_slt_sgt.py index f0fcb015d..3f9bc9842 100644 --- a/tests/evm/test_slt_sgt.py +++ b/tests/evm/test_slt_sgt.py @@ -6,12 +6,11 @@ Opcode, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_fp, rand_word, RLC +from zkevm_specs.util import rand_fq, rand_word, RLC RAND_1 = rand_word() @@ -186,7 +185,7 @@ @pytest.mark.parametrize("opcode, a, b, res", TESTING_DATA) def test_slt_sgt(opcode: Opcode, a: int, b: int, res: int): - randomness = rand_fp() + randomness = rand_fq() a = RLC(a, randomness) b = RLC(b, randomness) @@ -200,11 +199,11 @@ def test_slt_sgt(opcode: Opcode, a: int, b: int, res: int): tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - (9, RW.Read, RWTableTag.Stack, 1, 1022, 0, a, 0, 0, 0), - (10, RW.Read, RWTableTag.Stack, 1, 1023, 0, b, 0, 0, 0), - (11, RW.Write, RWTableTag.Stack, 1, 1023, 0, res, 0, 0, 0), - ] + RWDictionary(9) + .stack_read(1, 1022, a) + .stack_read(1, 1023, b) + .stack_write(1, 1023, res) + .rws ), ) diff --git a/tests/evm/test_sstore.py b/tests/evm/test_sstore.py index 0d11ab7c2..82a52d201 100644 --- a/tests/evm/test_sstore.py +++ b/tests/evm/test_sstore.py @@ -5,22 +5,22 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, CallContextFieldTag, Transaction, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util.param import ( +from zkevm_specs.util import ( + rand_fq, + rand_address, + RLC, COLD_SLOAD_COST, - WARM_STORAGE_READ_COST, SLOAD_GAS, SSTORE_SET_GAS, SSTORE_RESET_GAS, SSTORE_CLEARS_SCHEDULE, ) -from zkevm_specs.util import rand_fp, rand_address, RLC def gen_test_cases(): @@ -78,24 +78,24 @@ def gen_test_cases(): @pytest.mark.parametrize( - "tx, storage_key_be_bytes, value_be_bytes, value_prev_be_bytes, original_value_be_bytes, warm, result", + "tx, storage_key_be_bytes, value_be_bytes, value_prev_be_bytes, original_value_be_bytes, warm, is_success", TESTING_DATA, ) def test_sstore( tx: Transaction, storage_key_be_bytes: bytes, value_be_bytes: bytes, - value_prev_be_bytes: int, - original_value_be_bytes: int, + value_prev_be_bytes: bytes, + original_value_be_bytes: bytes, warm: bool, - result: bool, + is_success: bool, ): - randomness = rand_fp() + randomness = rand_fq() - storage_key = RLC(bytes(reversed(storage_key_be_bytes)), randomness) - value = RLC(bytes(reversed(value_be_bytes)), randomness) - value_prev = RLC(bytes(reversed(value_prev_be_bytes)), randomness) - original_value = RLC(bytes(reversed(original_value_be_bytes)), randomness) + storage_key = int.from_bytes(storage_key_be_bytes, "big") + value = int.from_bytes(value_be_bytes, "big") + value_prev = int.from_bytes(value_prev_be_bytes, "big") + value_committed = int.from_bytes(original_value_be_bytes, "big") bytecode = Bytecode().push32(storage_key_be_bytes).push32(value_be_bytes).sstore().stop() bytecode_hash = RLC(bytecode.hash(), randomness) @@ -103,8 +103,8 @@ def test_sstore( if value_prev == value: expected_gas_cost = SLOAD_GAS else: - if original_value == value_prev: - if original_value == 0: + if value_committed == value_prev: + if value_committed == 0: expected_gas_cost = SSTORE_SET_GAS else: expected_gas_cost = SSTORE_RESET_GAS @@ -113,20 +113,20 @@ def test_sstore( if not warm: expected_gas_cost = expected_gas_cost + COLD_SLOAD_COST - old_gas_refund = 15000 - gas_refund = old_gas_refund + gas_refund_prev = 15000 + gas_refund = gas_refund_prev if value_prev != value: - if original_value == value_prev: - if original_value != 0 and value == 0: + if value_committed == value_prev: + if value_committed != 0 and value == 0: gas_refund = gas_refund + SSTORE_CLEARS_SCHEDULE else: - if original_value != 0: + if value_committed != 0: if value_prev == 0: gas_refund = gas_refund - SSTORE_CLEARS_SCHEDULE if value == 0: gas_refund = gas_refund + SSTORE_CLEARS_SCHEDULE - if original_value == value: - if original_value == 0: + if value_committed == value: + if value_committed == 0: gas_refund = gas_refund + SSTORE_SET_GAS - SLOAD_GAS else: gas_refund = gas_refund + SSTORE_RESET_GAS - SLOAD_GAS @@ -136,125 +136,19 @@ def test_sstore( tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set( - [ - ( - 1, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.TxId, - 0, - tx.id, - 0, - 0, - 0, - ), - ( - 2, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.RwCounterEndOfReversion, - 0, - 0 if result else 14, - 0, - 0, - 0, - ), - ( - 3, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.IsPersistent, - 0, - result, - 0, - 0, - 0, - ), - ( - 4, - RW.Read, - RWTableTag.CallContext, - 1, - CallContextFieldTag.CalleeAddress, - 0, - tx.callee_address, - 0, - 0, - 0, - ), - (5, RW.Read, RWTableTag.Stack, 1, 1022, 0, storage_key, 0, 0, 0), - (6, RW.Read, RWTableTag.Stack, 1, 1023, 0, value, 0, 0, 0), - ( - 7, - RW.Write, - RWTableTag.AccountStorage, - tx.callee_address, - storage_key, - 0, - value, - value_prev, - tx.id, - original_value, - ), - ( - 8, - RW.Write, - RWTableTag.TxAccessListAccountStorage, - tx.id, - tx.callee_address, - storage_key, - 1, - 1 if warm else 0, - 0, - 0, - ), - (9, RW.Write, RWTableTag.TxRefund, tx.id, 0, 0, gas_refund, old_gas_refund, 0, 0), - ] - + ( - [] - if result - else [ - ( - 12, - RW.Write, - RWTableTag.TxRefund, - tx.id, - 0, - 0, - old_gas_refund, - gas_refund, - 0, - 0, - ), - ( - 13, - RW.Write, - RWTableTag.TxAccessListAccountStorage, - tx.id, - tx.callee_address, - storage_key, - 1 if warm else 0, - 1, - 0, - 0, - ), - ( - 14, - RW.Write, - RWTableTag.AccountStorage, - tx.callee_address, - storage_key, - 0, - value_prev, - value, - tx.id, - original_value, - ), - ] - ) + # fmt: off + RWDictionary(1) + .call_context_read(1, CallContextFieldTag.TxId, tx.id) + .call_context_read(1, CallContextFieldTag.RwCounterEndOfReversion, 0 if is_success else 14) + .call_context_read(1, CallContextFieldTag.IsPersistent, is_success) + .call_context_read(1, CallContextFieldTag.CalleeAddress, tx.callee_address) + .stack_read(1, 1022, RLC(storage_key, randomness)) + .stack_read(1, 1023, RLC(value, randomness)) + .account_storage_write(tx.callee_address, RLC(storage_key, randomness), RLC(value, randomness), RLC(value_prev, randomness), tx.id, RLC(value_committed, randomness), rw_counter_of_reversion=None if is_success else 14) + .tx_access_list_account_storage_write(tx.id, tx.callee_address, RLC(storage_key, randomness), 1, 1 if warm else 0, rw_counter_of_reversion=None if is_success else 13) + .tx_refund_write(tx.id, gas_refund, gas_refund_prev, rw_counter_of_reversion=None if is_success else 12) + .rws + # fmt: on ), ) @@ -275,7 +169,7 @@ def test_sstore( gas_left=expected_gas_cost, ), StepState( - execution_state=ExecutionState.STOP if result else ExecutionState.REVERT, + execution_state=ExecutionState.STOP if is_success else ExecutionState.REVERT, rw_counter=10, call_id=1, is_root=True, diff --git a/tests/evm/test_timestamp.py b/tests/evm/test_timestamp.py index 6fedadaa0..fcc4031ee 100644 --- a/tests/evm/test_timestamp.py +++ b/tests/evm/test_timestamp.py @@ -5,19 +5,18 @@ StepState, verify_steps, Tables, - RWTableTag, - RW, Block, Bytecode, + RWDictionary, ) -from zkevm_specs.util import rand_range, rand_fp, RLC, U64 +from zkevm_specs.util import rand_range, rand_fq, RLC, U64 TESTING_DATA = (0, 1, 2**64 - 1, rand_range(2**64)) @pytest.mark.parametrize("timestamp", TESTING_DATA) def test_timestamp(timestamp: U64): - randomness = rand_fp() + randomness = rand_fq() block = Block(timestamp=timestamp) @@ -28,11 +27,7 @@ def test_timestamp(timestamp: U64): block_table=set(block.table_assignments(randomness)), tx_table=set(), bytecode_table=set(bytecode.table_assignments(randomness)), - rw_table=set( - [ - (9, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(timestamp, randomness, 8), 0, 0, 0), - ] - ), + rw_table=set(RWDictionary(9).stack_write(1, 1023, RLC(timestamp, randomness)).rws), ) verify_steps( diff --git a/tests/test_bitwise.py b/tests/test_bitwise.py index 9910ffe5a..0f923fb5a 100644 --- a/tests/test_bitwise.py +++ b/tests/test_bitwise.py @@ -1,7 +1,7 @@ import random import pytest -from zkevm_specs.encoding import u256_to_u8s, u8s_to_u256 +from zkevm_specs.encoding import u256_to_u8s from zkevm_specs.opcode import check_and, check_or, check_xor @@ -28,8 +28,7 @@ def test_check_or(): def test_check_xor(): - for i in range(5): - print(i) + for _ in range(5): a = random.randint(0, 2**256) b = random.randint(0, 2**256) c = a ^ b diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index 835e1564c..270c56e24 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -1,9 +1,9 @@ -from zkevm_specs.evm.opcode import Opcode, is_push -from zkevm_specs.bytecode import * import traceback from copy import deepcopy -from zkevm_specs.evm import Bytecode -from zkevm_specs.util import RLC, rand_fp + +from zkevm_specs.bytecode import * +from zkevm_specs.evm import Opcode, Bytecode, BytecodeTableRow, is_push +from zkevm_specs.util import RLC, rand_fq # Unroll the bytecode def unroll(bytecode, randomness): @@ -28,7 +28,7 @@ def verify(k, bytecodes, randomness, success): k = 10 -randomness = rand_fp() +randomness = rand_fq() def test_bytecode_unrolling(): @@ -50,7 +50,7 @@ def test_bytecode_unrolling(): # Set the hash of the complete bytecode in the rows hash = RLC(bytes(reversed(keccak256(bytes(bytecode)))), randomness) for i in range(len(rows)): - rows[i] = (hash, rows[i][1], rows[i][2], rows[i][3]) + rows[i] = BytecodeTableRow(hash.expr(), rows[i][1], rows[i][2], rows[i][3]) # Unroll the bytecode unrolled = unroll(bytes(bytecode), randomness) # Check if the bytecode was unrolled correctly @@ -92,19 +92,19 @@ def test_bytecode_invalid_hash_data(): # Change the hash on the first position invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = BytecodeTableRow(row.bytecode_hash + 1, row.index, row.byte, row.is_code) verify(k, [invalid], randomness, False) # Change the hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = BytecodeTableRow(row.bytecode_hash + 1, row.index, row.byte, row.is_code) verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (1, row[1], row[2], row[3]) + invalid.rows[idx] = BytecodeTableRow(1, row.index, row.byte, row.is_code) verify(k, [invalid], randomness, False) @@ -115,14 +115,16 @@ def test_bytecode_invalid_index(): # Start the index at 1 invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[idx] = BytecodeTableRow( + row.bytecode_hash + 1, row.index, row.byte, row.is_code + ) verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) - invalid_cell = invalid.rows[-1][0] - invalid_cell.value -= 1 - invalid.rows[-1] = (invalid_cell, row[1], row[2], row[3]) + invalid.rows[-1] = BytecodeTableRow( + invalid.rows[-1].bytecode_hash - 1, row.index, row.byte, row.is_code + ) verify(k, [invalid], randomness, False) @@ -133,19 +135,19 @@ def test_bytecode_invalid_byte_data(): # Change the first byte invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0], row[1], row[2], 9) + invalid.rows[0] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 9) verify(k, [invalid], randomness, False) # Change a byte on another position invalid = deepcopy(unrolled) row = unrolled.rows[5] - invalid.rows[5] = (row[0], row[1], row[2], 6) + invalid.rows[5] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 6) verify(k, [invalid], randomness, False) # Set a byte value out of range invalid = deepcopy(unrolled) row = unrolled.rows[3] - invalid.rows[3] = (row[0], row[1], row[2], 256) + invalid.rows[3] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 256) verify(k, [invalid], randomness, False) @@ -169,17 +171,17 @@ def test_bytecode_invalid_is_code(): # Mark the 3rd byte as code (is push data from the first PUSH1) invalid = deepcopy(unrolled) row = unrolled.rows[2] - invalid.rows[2] = (row[0], row[1], 1, row[3]) + invalid.rows[2] = BytecodeTableRow(row.bytecode_hash, row.index, 1, row.is_code) verify(k, [invalid], randomness, False) # Mark the 4rd byte as data (is code) invalid = deepcopy(unrolled) row = unrolled.rows[3] - invalid.rows[3] = (row[0], row[1], 0, row[3]) + invalid.rows[3] = BytecodeTableRow(row.bytecode_hash, row.index, 0, row.is_code) verify(k, [invalid], randomness, False) # Mark the 7th byte as code (is data for the PUSH7) invalid = deepcopy(unrolled) row = unrolled.rows[6] - invalid.rows[6] = (row[0], row[1], 1, row[3]) + invalid.rows[6] = BytecodeTableRow(row.bytecode_hash, row.index, 1, row.is_code) verify(k, [invalid], randomness, False) diff --git a/tests/test_state_circuit.py b/tests/test_state_circuit.py index 5a431533a..4f91f7e1e 100644 --- a/tests/test_state_circuit.py +++ b/tests/test_state_circuit.py @@ -1,9 +1,10 @@ import traceback from typing import Union, List + from zkevm_specs.state import * -from zkevm_specs.util import rand_fp, FQ, RLC +from zkevm_specs.util import rand_fq, FQ, RLC -randomness = rand_fp() +randomness = rand_fq() r = randomness # Verify the state circuit with the given data