diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5a0007af9..22f778391 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,5 +27,7 @@ jobs: make install - name: Lint run: make lint + - name: Type check + run: make type - name: Test with pytest run: make test diff --git a/Makefile b/Makefile index d90c6de50..32c551cdd 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ lint: ## Check whether the code is formated correctly black . --check mdformat specs/ --number --check +type: ## Check the typing of the Python code + mypy src + test: ## Run tests pytest diff --git a/setup.cfg b/setup.cfg index e77608300..81a802224 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ test = lint = black >= 22.1.0 mdformat >= 0.7.13 + mypy >= 0.931 diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 8b7ef60b4..1beca718a 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -24,7 +24,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code @@ -178,7 +178,7 @@ def assign_push_table(): # Generate keccak table -def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int): +def assign_keccak_table(bytecodes: Sequence[bytes], randomness: FQ): keccak_table = [] for bytecode in bytecodes: hash = RLC(bytes(reversed(keccak256(bytecode))), randomness) diff --git a/src/zkevm_specs/encoding/lookup.py b/src/zkevm_specs/encoding/lookup.py index 774b892dd..808ab0d58 100644 --- a/src/zkevm_specs/encoding/lookup.py +++ b/src/zkevm_specs/encoding/lookup.py @@ -2,7 +2,7 @@ class LookupTable: - columns: Tuple[str] + columns: Tuple[str, ...] rows: Set[Tuple[int, ...]] def __init__(self, columns: Sequence[str]) -> None: diff --git a/src/zkevm_specs/encoding/utils.py b/src/zkevm_specs/encoding/utils.py index 9c0db0cd2..1b64996ee 100644 --- a/src/zkevm_specs/encoding/utils.py +++ b/src/zkevm_specs/encoding/utils.py @@ -1,8 +1,8 @@ -from typing import Sequence, Tuple, List +from typing import Sequence, Tuple from .typing import U8, U256, U64 -def is_circuit_code(func) -> object: +def is_circuit_code(func): """ A no-op decorator just to mark the function """ @@ -15,19 +15,19 @@ def wrapper(*args, **kargs): def u256_to_u8s(x: U256) -> Tuple[U8, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 8 * i) & 0xFF for i in range(32)) + return tuple(U8((x >> 8 * i) & 0xFF) for i in range(32)) def u256_to_u64s(x: U256) -> Tuple[U64, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF for i in range(4)) + return tuple(U64((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)) def u8s_to_u256(xs: Sequence[U8]) -> U256: assert len(xs) == 32 for u8 in xs: assert 0 <= u8 <= 255 - return sum(x * (2 ** (8 * i)) for i, x in enumerate(xs)) + return U256(sum(x * (2 ** (8 * i)) for i, x in enumerate(xs))) # [u8;32]->[u64;4] @@ -37,5 +37,5 @@ def u8s_to_u64s(xs: Sequence[U8]) -> Tuple[U64, ...]: A = [u64_0] * 4 # A = A3A2A1A0 for i in range(4): for j in range(8): - A[i] += U64(xs[j + 8 * i] * (2 ** (8 * j))) + A[i] += xs[j + 8 * i] * (2 ** (8 * j)) return tuple(A) diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index f282cd9cf..68cac17b3 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -1,3 +1,4 @@ +# type: ignore from typing import Callable, Dict from ..execution_state import ExecutionState diff --git a/src/zkevm_specs/evm/execution/add.py b/src/zkevm_specs/evm/execution/add.py index 8bd6f2e54..a0e09b61a 100644 --- a/src/zkevm_specs/evm/execution/add.py +++ b/src/zkevm_specs/evm/execution/add.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index 58de3305c..1fd2c9814 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,4 +1,5 @@ -from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH +# type: ignore +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH, RLC from ..execution_state import ExecutionState from ..instruction import Instruction, Transition from ..precompiled import PrecompiledAddress diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 55d0b3b20..f2ea8f845 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/block_timestamp.py b/src/zkevm_specs/evm/execution/block_timestamp.py index c503e2750..6ab10c9dc 100644 --- a/src/zkevm_specs/evm/execution/block_timestamp.py +++ b/src/zkevm_specs/evm/execution/block_timestamp.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/calldatacopy.py b/src/zkevm_specs/evm/execution/calldatacopy.py index 599b3dc58..251612d0e 100644 --- a/src/zkevm_specs/evm/execution/calldatacopy.py +++ b/src/zkevm_specs/evm/execution/calldatacopy.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import N_BYTES_MEMORY_ADDRESS, FQ from ..execution_state import ExecutionState from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/execution/calldatasize.py b/src/zkevm_specs/evm/execution/calldatasize.py index 2495577a0..1797b62cd 100644 --- a/src/zkevm_specs/evm/execution/calldatasize.py +++ b/src/zkevm_specs/evm/execution/calldatasize.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 58bf046ea..e633d6946 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/callvalue.py b/src/zkevm_specs/evm/execution/callvalue.py index 4e37a5b14..16ecb5907 100644 --- a/src/zkevm_specs/evm/execution/callvalue.py +++ b/src/zkevm_specs/evm/execution/callvalue.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/end_block.py b/src/zkevm_specs/evm/execution/end_block.py index 4ba98a1e9..6d3e81a7a 100644 --- a/src/zkevm_specs/evm/execution/end_block.py +++ b/src/zkevm_specs/evm/execution/end_block.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py index e9d295b32..76f5cb89b 100644 --- a/src/zkevm_specs/evm/execution/end_tx.py +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED from ..execution_state import ExecutionState from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/execution/gas.py b/src/zkevm_specs/evm/execution/gas.py index d0a0faea5..a56e3d237 100644 --- a/src/zkevm_specs/evm/execution/gas.py +++ b/src/zkevm_specs/evm/execution/gas.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import CallContextFieldTag, TxContextFieldTag diff --git a/src/zkevm_specs/evm/execution/gasprice.py b/src/zkevm_specs/evm/execution/gasprice.py index 3323692f1..ee2751f99 100644 --- a/src/zkevm_specs/evm/execution/gasprice.py +++ b/src/zkevm_specs/evm/execution/gasprice.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import CallContextFieldTag, TxContextFieldTag diff --git a/src/zkevm_specs/evm/execution/jump.py b/src/zkevm_specs/evm/execution/jump.py index 8e46d6d1f..b33a8cbc9 100644 --- a/src/zkevm_specs/evm/execution/jump.py +++ b/src/zkevm_specs/evm/execution/jump.py @@ -1,3 +1,4 @@ +# type: ignore from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/jumpi.py b/src/zkevm_specs/evm/execution/jumpi.py index c485b0b74..d8a2d66f0 100644 --- a/src/zkevm_specs/evm/execution/jumpi.py +++ b/src/zkevm_specs/evm/execution/jumpi.py @@ -1,3 +1,4 @@ +# type: ignore from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/memory_copy.py b/src/zkevm_specs/evm/execution/memory_copy.py index cabb93b71..43ac76008 100644 --- a/src/zkevm_specs/evm/execution/memory_copy.py +++ b/src/zkevm_specs/evm/execution/memory_copy.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import FQ, N_BYTES_MEMORY_SIZE from ..execution_state import ExecutionState from ..instruction import Instruction, Transition @@ -17,8 +18,6 @@ def copy_to_memory(instruction: Instruction): instruction, MAX_COPY_BYTES, aux.src_addr, aux.src_addr_end, aux.bytes_left ) - data = [] - rw_counter_delta = 0 for i in range(MAX_COPY_BYTES): if not buffer_reader.read_flag(i): byte = FQ.zero() diff --git a/src/zkevm_specs/evm/execution/push.py b/src/zkevm_specs/evm/execution/push.py index cb5602684..a2927c9e9 100644 --- a/src/zkevm_specs/evm/execution/push.py +++ b/src/zkevm_specs/evm/execution/push.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/selfbalance.py b/src/zkevm_specs/evm/execution/selfbalance.py index b69448757..b746573e4 100644 --- a/src/zkevm_specs/evm/execution/selfbalance.py +++ b/src/zkevm_specs/evm/execution/selfbalance.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import AccountFieldTag, CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/slt_sgt.py b/src/zkevm_specs/evm/execution/slt_sgt.py index 752ef1bc1..1d06775be 100644 --- a/src/zkevm_specs/evm/execution/slt_sgt.py +++ b/src/zkevm_specs/evm/execution/slt_sgt.py @@ -1,4 +1,6 @@ -from typing import Sequence, Tuple +# type: ignore + +from zkevm_specs.util import FQ from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -7,7 +9,7 @@ def scmp(instruction: Instruction): opcode = instruction.opcode_lookup(True) - is_sgt, _ = instruction.pair_select(opcode, Opcode.SGT, Opcode.SLT) + is_sgt, _ = instruction.pair_select(opcode, FQ(Opcode.SGT.value), FQ(Opcode.SLT.value)) a = instruction.stack_pop() b = instruction.stack_pop() diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 6cfd3c63f..d37b3d5ea 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -3,8 +3,6 @@ from typing import Optional, Sequence, Tuple, Union, Mapping from ..util import ( - Array4, - Array10, FQ, IntOrFQ, RLC, @@ -27,6 +25,11 @@ TxContextFieldTag, RW, RWTableTag, + RWTableRow, + FixedTableRow, + BlockTableRow, + TxTableRow, + BytecodeTableRow, ) @@ -49,13 +52,16 @@ def __init__(self, kind: TransitionKind, value: Optional[int] = None) -> None: self.kind = kind self.value = value - def same() -> Transition: + @classmethod + def same(cls) -> Transition: return Transition(TransitionKind.Same) - def delta(delta: int): + @classmethod + def delta(cls, delta: int): return Transition(TransitionKind.Delta, delta) - def to(to: int): + @classmethod + def to(cls, to: int): return Transition(TransitionKind.To, to) @@ -107,7 +113,7 @@ def constrain_bool(self, num: FQ): def constrain_gas_left_not_underflow(self, gas_left: FQ): self.range_check(gas_left, N_BYTES_GAS) - def constrain_step_state_transition(self, **kwargs: Mapping[str, Transition]): + def constrain_step_state_transition(self, **kwargs: Transition): keys = set( [ "rw_counter", @@ -198,7 +204,7 @@ def step_state_transition_in_same_context( ) def sum(self, values: Sequence[FQ]) -> FQ: - return sum(values) + return FQ(sum(values)) def is_zero(self, value: Union[FQ, RLC]) -> bool: return value == 0 @@ -223,8 +229,8 @@ def pair_select(self, value: FQ, lhs: FQ, rhs: FQ) -> Tuple[bool, bool]: def constant_divmod( self, numerator: IntOrFQ, denominator: IntOrFQ, n_bytes: int ) -> Tuple[FQ, FQ]: - quotient, remainder = divmod(FQ(numerator).n, FQ(denominator).n) - quotient, remainder = FQ(quotient), FQ(remainder) + _quotient, _remainder = divmod(FQ(numerator).n, FQ(denominator).n) + quotient, remainder = FQ(_quotient), FQ(_remainder) self.range_check(quotient, n_bytes) return quotient, remainder @@ -246,12 +252,11 @@ def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, FQ]: addends_lo, addends_hi = list(zip(*map(self.word_to_lo_hi, addends))) carry_lo, sum_lo = divmod(self.sum(addends_lo).n, 1 << 128) - carry_hi, sum_hi = divmod((self.sum(addends_hi) + carry_lo).n, 1 << 128) + carry_hi, sum_hi = divmod(self.sum(addends_hi).n + carry_lo, 1 << 128) sum_bytes = sum_lo.to_bytes(16, "little") + sum_hi.to_bytes(16, "little") - carry_hi = FQ(carry_hi) - return RLC(sum_bytes, self.randomness), carry_hi + return RLC(sum_bytes, self.randomness, 32), FQ(carry_hi) def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: minuend_lo, minuend_hi = self.word_to_lo_hi(minuend) @@ -273,19 +278,12 @@ def mul_word_by_u64(self, multiplicand: RLC, multiplier: FQ) -> Tuple[RLC, FQ]: quotient_hi, product_hi = divmod((multiplicand_hi * multiplier + quotient_lo).n, 1 << 128) product_bytes = product_lo.to_bytes(16, "little") + product_hi.to_bytes(16, "little") - quotient_hi = FQ(quotient_hi) - return RLC(product_bytes, self.randomness), quotient_hi + return RLC(product_bytes, self.randomness), FQ(quotient_hi) def rlc_to_le_bytes(self, rlc: RLC) -> bytes: return rlc.le_bytes - def rlc_to_fq_unchecked(self, rlc: RLC, n_bytes: int) -> FQ: - rlc_le_bytes = self.rlc_to_le_bytes(rlc) - return self.bytes_to_fq(rlc_le_bytes[:n_bytes]), self.is_zero( - self.sum(rlc_le_bytes[n_bytes:]) - ) - def rlc_to_fq_exact(self, rlc: RLC, n_bytes: int) -> FQ: rlc_le_bytes = self.rlc_to_le_bytes(rlc) @@ -311,7 +309,7 @@ def bytes_to_fq(self, value: bytes) -> FQ: return FQ(int.from_bytes(value, "little")) def range_lookup(self, value: FQ, range: int): - self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), value, 0, 0]) + self.tables.fixed_lookup(FixedTableTag.range_table_tag(range), value) def byte_range_lookup(self, value: FQ): assert isinstance(value, FQ), f"Expect type FQ, but get type {type(value)}" @@ -325,28 +323,30 @@ def range_check(self, value: FQ, n_bytes: int) -> bytes: except OverflowError: raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") - def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> Array4: - return self.tables.fixed_lookup([tag] + inputs) + def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> FixedTableRow: + return self.tables.fixed_lookup(tag, *inputs) def block_context_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ.zero()) -> FQ: - return self.tables.block_lookup([tag, index])[2] + return self.tables.block_lookup(tag, index).value def tx_context_lookup( self, tx_id: FQ, field_tag: TxContextFieldTag, index: FQ = FQ.zero() - ) -> Union[FQ, RLC]: - return self.tables.tx_lookup([tx_id, field_tag, index])[3] + ) -> FQ: + return self.tables.tx_lookup(tx_id, field_tag, index).value def tx_calldata_lookup(self, tx_id: FQ, index: FQ) -> FQ: - return self.tables.tx_lookup([tx_id, TxContextFieldTag.CallData, index])[3] + return self.tables.tx_lookup(tx_id, TxContextFieldTag.CallData, index).value def bytecode_lookup(self, bytecode_hash: RLC, index: FQ, is_code: FQ) -> FQ: - return self.tables.bytecode_lookup([bytecode_hash, index, Tables._, is_code])[2] + return self.tables.bytecode_lookup(bytecode_hash.value, index, is_code).byte def tx_gas_price(self, tx_id: FQ) -> FQ: return self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice) def responsible_opcode_lookup(self, opcode: int): - self.fixed_lookup(FixedTableTag.ResponsibleOpcode, [self.curr.execution_state, opcode]) + self.tables.fixed_lookup( + FixedTableTag.ResponsibleOpcode, FQ(self.curr.execution_state.value), FQ(opcode) + ) def opcode_lookup(self, is_code: bool) -> FQ: index = self.curr.program_counter + self.program_counter_offset @@ -359,28 +359,16 @@ def opcode_lookup_at(self, index: FQ, is_code: bool) -> FQ: "The opcode source when is_root and is_create (root creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.code_source, index, is_code) + return self.bytecode_lookup(self.curr.code_source, index, FQ(is_code)) def rw_lookup( - self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None - ) -> Array10: + self, rw: RW, tag: RWTableTag, inputs: Sequence[IntOrFQ], rw_counter: Optional[FQ] = None + ) -> RWTableRow: if rw_counter is None: rw_counter = self.curr.rw_counter + self.rw_counter_offset self.rw_counter_offset += 1 - return self.tables.rw_lookup([rw_counter, rw, tag] + inputs) - - def state_write_only_persistent( - self, - tag: RWTableTag, - inputs: Sequence[int], - is_persistent: bool, - ) -> Array10: - assert tag.write_only_persistent() - - if is_persistent: - return self.rw_lookup(RW.Write, tag, inputs) - return 10 * [None] + return self.tables.rw_lookup(rw_counter, rw, tag, *inputs) def state_write_with_reversion( self, @@ -388,61 +376,73 @@ def state_write_with_reversion( inputs: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Array10: + _state_write_counter: Optional[FQ] = None, + ) -> RWTableRow: assert tag.write_with_reversion() row = self.rw_lookup(RW.Write, tag, inputs) - if state_write_counter is None: - state_write_counter = self.curr.state_write_counter + self.state_write_counter_offset + if _state_write_counter is None: self.state_write_counter_offset += 1 + state_write_counter = ( + self.curr.state_write_counter + self.state_write_counter_offset + if _state_write_counter is None + else _state_write_counter + ) + rw_counter = rw_counter_end_of_reversion - state_write_counter if not is_persistent: # Swap value and value_prev - inputs = list(row[3:]) - inputs[-3], inputs[-4] = inputs[-4], inputs[-3] - self.rw_lookup(RW.Write, tag, inputs, rw_counter=rw_counter) + inputs2 = [ + row.key2, + row.key3, + row.key4, + row.value_prev, + row.value, + row.aux1, + row.aux2, + ] + self.rw_lookup(RW.Write, tag, inputs2, rw_counter=rw_counter) return row def call_context_lookup( - self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[int] = None - ) -> Union[FQ, RLC]: - if call_id is None: - call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag])[-4] + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[FQ] = None + ) -> FQ: + _call_id = self.curr.call_id if call_id is None else call_id + + return self.rw_lookup(rw, RWTableTag.CallContext, [_call_id, field_tag.value]).value - def stack_pop(self) -> Union[FQ, RLC]: + def stack_pop(self) -> FQ: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 - return self.stack_lookup(False, stack_pointer_offset) + return self.stack_lookup(RW.Read, stack_pointer_offset) - def stack_push(self) -> Union[FQ, RLC]: + def stack_push(self) -> FQ: self.stack_pointer_offset -= 1 - return self.stack_lookup(True, self.stack_pointer_offset) + return self.stack_lookup(RW.Write, self.stack_pointer_offset) - def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[FQ, RLC]: + def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> FQ: stack_pointer = self.curr.stack_pointer + stack_pointer_offset - return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer])[-4] + return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer]).value - def memory_write(self, memory_address: int, call_id: Optional[int] = None) -> FQ: + def memory_write(self, memory_address: int, call_id: Optional[FQ] = None) -> FQ: return self.memory_lookup(RW.Write, memory_address, call_id) - def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = None) -> FQ: - if call_id is None: - call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address])[-4] + def memory_lookup(self, rw: RW, memory_address: int, _call_id: Optional[FQ] = None) -> FQ: + call_id = self.curr.call_id if _call_id is None else _call_id + + return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address]).value def tx_refund_read(self, tx_id) -> FQ: row = self.rw_lookup(RW.Read, RWTableTag.TxRefund, [tx_id]) - return row[-4] + return row.value def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> FQ: row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) - return row[-4] + return row.value def account_write( self, @@ -454,7 +454,7 @@ def account_write( RWTableTag.Account, [account_address, account_field_tag], ) - return row[-4], row[-3] + return row.value, row.value_prev def account_write_with_reversion( self, @@ -462,8 +462,8 @@ def account_write_with_reversion( account_field_tag: AccountFieldTag, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: + state_write_counter: Optional[FQ] = None, + ) -> Tuple[RLC, RLC]: row = self.state_write_with_reversion( RWTableTag.Account, [account_address, account_field_tag], @@ -471,23 +471,23 @@ def account_write_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4], row[-3] + return row.value, row.value_prev - def add_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: + def add_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = self.add_words([balance_prev, *values]) - self.constrain_equal(balance, result) + self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev def add_balance_with_reversion( self, account_address: int, - values: Sequence[int], + values: Sequence[RLC], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: + state_write_counter: Optional[FQ] = None, + ) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write_with_reversion( account_address, AccountFieldTag.Balance, @@ -496,25 +496,25 @@ def add_balance_with_reversion( state_write_counter, ) result, carry = self.add_words([balance_prev, *values]) - self.constrain_equal(balance, result) + self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: + def sub_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) result, carry = self.add_words([balance, *values]) - self.constrain_equal(balance_prev, result) + self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev def sub_balance_with_reversion( self, account_address: int, - values: Sequence[int], + values: Sequence[RLC], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[FQ, FQ]: + state_write_counter: Optional[FQ] = None, + ) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write_with_reversion( account_address, AccountFieldTag.Balance, @@ -523,7 +523,7 @@ def sub_balance_with_reversion( state_write_counter, ) result, carry = self.add_words([balance, *values]) - self.constrain_equal(balance_prev, result) + self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev @@ -537,7 +537,7 @@ def add_account_to_access_list( RWTableTag.TxAccessListAccount, [tx_id, account_address, 0, 1], ) - return row[-4] - row[-3] + return row.value - row.value_prev def add_account_to_access_list_with_reversion( self, @@ -545,7 +545,7 @@ def add_account_to_access_list_with_reversion( account_address: int, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> FQ: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccount, @@ -554,7 +554,7 @@ def add_account_to_access_list_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4] - row[-3] + return row.value - row.value_prev def add_account_storage_to_access_list( self, @@ -567,7 +567,7 @@ def add_account_storage_to_access_list( RWTableTag.TxAccessListAccountStorage, [tx_id, account_address, storage_key, 1], ) - return row[-4] - row[-3] + return (row.value - row.value_prev) == 1 def add_account_storage_to_access_list_with_reversion( self, @@ -576,7 +576,7 @@ def add_account_storage_to_access_list_with_reversion( storage_key: int, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> bool: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccountStorage, @@ -585,14 +585,14 @@ def add_account_storage_to_access_list_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4] - row[-3] + return (row.value - row.value_prev) == 1 def transfer_with_gas_fee( self, sender_address: int, receiver_address: int, - value: int, - gas_fee: int, + value: RLC, + gas_fee: RLC, is_persistent: bool, rw_counter_end_of_reversion: int, ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: @@ -614,11 +614,11 @@ def transfer( self, sender_address: int, receiver_address: int, - value: int, + value: RLC, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + state_write_counter: Optional[FQ] = None, + ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: sender_balance_pair = self.sub_balance_with_reversion( sender_address, [value], @@ -636,11 +636,11 @@ def transfer( return sender_balance_pair, receiver_balance_pair def memory_offset_and_length(self, offset: RLC, length: RLC) -> Tuple[FQ, FQ]: - length = self.rlc_to_fq_exact(length, N_BYTES_MEMORY_SIZE) - if self.is_zero(length): + _length = self.rlc_to_fq_exact(length, N_BYTES_MEMORY_SIZE) + if self.is_zero(_length): return FQ.zero(), FQ.zero() - offset = self.rlc_to_fq_exact(offset, N_BYTES_MEMORY_ADDRESS) - return offset, length + _offset = self.rlc_to_fq_exact(offset, N_BYTES_MEMORY_ADDRESS) + return _offset, _length def memory_gas_cost(self, memory_size: FQ) -> FQ: quadratic_cost, _ = self.constant_divmod( @@ -654,7 +654,7 @@ def memory_expansion_constant_length(self, offset: FQ, length: FQ) -> Tuple[FQ, next_memory_size = self.max(self.curr.memory_size, memory_size, N_BYTES_MEMORY_SIZE) - memory_gas_cost = self.memory_expansion_gas_cost(self.curr.memory_size) + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) memory_gas_cost_next = self.memory_gas_cost(next_memory_size) memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost @@ -665,7 +665,7 @@ def memory_expansion_dynamic_length( cd_offset: FQ, cd_length: FQ, rd_offset: Optional[FQ] = None, - rd_length: Optional[FQ] = None, + _rd_length: Optional[FQ] = None, ) -> Tuple[FQ, FQ]: cd_memory_size, _ = self.constant_divmod( cd_offset + cd_length + 31, 32, N_BYTES_MEMORY_SIZE @@ -673,6 +673,7 @@ def memory_expansion_dynamic_length( next_memory_size = self.max(self.curr.memory_size, cd_memory_size, N_BYTES_MEMORY_SIZE) if rd_offset is not None: + rd_length = 0 if _rd_length is None else _rd_length rd_memory_size, _ = self.constant_divmod( rd_offset + rd_length + 31, 32, N_BYTES_MEMORY_SIZE ) diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index f990f9100..f412c3289 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -1,3 +1,4 @@ +# type: ignore from typing import Sequence from ..util import FQ diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 8f58cdc7e..bf7acd647 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Final, Dict, Sequence, Tuple, Union +from typing import Final, Dict, Sequence, Tuple, Union, Optional from ..util.param import * @@ -184,10 +184,12 @@ class OpcodeInfo: max_stack_pointer: int constant_gas_cost: int has_dynamic_gas: bool - pure_memory_expansion_info: Tuple[ - int, # offset stack_pointer_offset - int, # length stack_pointer_offset - int, # constant length + pure_memory_expansion_info: Optional[ + Tuple[ + int, # offset stack_pointer_offset + int, # length stack_pointer_offset + int, # constant length + ] ] def __init__( @@ -196,7 +198,7 @@ def __init__( max_stack_pointer: int, constant_gas_cost: int, has_dynamic_gas: bool = False, - pure_memory_expansion_info: Union[Tuple[int, int, int], None] = None, + pure_memory_expansion_info: Optional[Tuple[int, int, int]] = None, ) -> None: self.min_stack_pointer = min_stack_pointer self.max_stack_pointer = max_stack_pointer diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index 28bcf2271..862189c8a 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -50,7 +50,7 @@ def __init__( call_id: int = 0, is_root: bool = False, is_create: bool = False, - code_source: int = 0, + code_source: RLC = RLC(0, FQ(0)), program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index c85c4e521..d8379d148 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Sequence, Set, Tuple -from enum import IntEnum, auto +from typing import Mapping, Sequence, Set, List, TypeVar, Any, Type, Optional, Dict, Union +from enum import IntEnum, auto, Enum from itertools import chain, product +from dataclasses import dataclass, field, asdict, fields -from ..util import FQ, RLC, Array3, Array4, Array10 +from ..util import FQ, IntOrFQ, RLC from .execution_state import ExecutionState from .opcode import ( invalid_opcodes, @@ -40,48 +41,58 @@ class FixedTableTag(IntEnum): StackOverflow = auto() # opcode, stack_pointer, 0 StackUnderflow = auto() # opcode, stack_pointer, 0 - def table_assignments(self) -> Sequence[Array4]: + def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range16: - return [(self, i, 0, 0) for i in range(16)] + return [FixedTableRow(self, FQ(i)) for i in range(16)] elif self == FixedTableTag.Range32: - return [(self, i, 0, 0) for i in range(32)] + return [FixedTableRow(self, FQ(i)) for i in range(32)] elif self == FixedTableTag.Range64: - return [(self, i, 0, 0) for i in range(64)] + return [FixedTableRow(self, FQ(i)) for i in range(64)] elif self == FixedTableTag.Range256: - return [(self, i, 0, 0) for i in range(256)] + return [FixedTableRow(self, FQ(i)) for i in range(256)] elif self == FixedTableTag.Range512: - return [(self, i, 0, 0) for i in range(512)] + return [FixedTableRow(self, FQ(i)) for i in range(512)] elif self == FixedTableTag.Range1024: - return [(self, i, 0, 0) for i in range(1024)] + return [FixedTableRow(self, FQ(i)) for i in range(1024)] elif self == FixedTableTag.SignByte: - return [(self, i, (i >> 7) * 0xFF, 0) for i in range(256)] + return [FixedTableRow(self, FQ(i), FQ((i >> 7) * 0xFF)) for i in range(256)] elif self == FixedTableTag.BitwiseAnd: - return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs & rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseOr: - return [(self, lhs, rhs, lhs | rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs | rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseXor: - return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs ^ rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.ResponsibleOpcode: return [ - (self, execution_state, opcode, 0) + FixedTableRow(self, FQ(execution_state), FQ(opcode)) for execution_state in list(ExecutionState) for opcode in execution_state.responsible_opcode() ] elif self == FixedTableTag.InvalidOpcode: - return [(self, opcode, 0, 0) for opcode in invalid_opcodes()] + return [FixedTableRow(self, FQ(opcode)) for opcode in invalid_opcodes()] elif self == FixedTableTag.StateWriteOpcode: - return [(self, opcode, 0, 0) for opcode in state_write_opcodes()] + return [FixedTableRow(self, FQ(opcode)) for opcode in state_write_opcodes()] elif self == FixedTableTag.StackOverflow: return [ - (self, opcode, stack_pointer, 0) + FixedTableRow(self, FQ(opcode), FQ(stack_pointer)) for opcode, stack_pointer in stack_underflow_pairs() ] elif self == FixedTableTag.StackUnderflow: return [ - (self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs() + FixedTableRow(self, FQ(opcode), FQ(stack_pointer)) + for opcode, stack_pointer in stack_overflow_pairs() ] else: - ValueError("Unreacheable") + raise ValueError("Unreacheable") def range_table_tag(range: int) -> FixedTableTag: if range == 16: @@ -139,10 +150,13 @@ class TxContextFieldTag(IntEnum): CallData = auto() -class RW: +class RW(Enum): Read = False Write = True + def __int__(self): + return self.value + class RWTableTag(IntEnum): """ @@ -232,20 +246,84 @@ class CallContextFieldTag(IntEnum): StateWriteCounter = auto() +class WrongQueryKey(Exception): + def __init__(self, table_name: str, diff: Set[str]) -> None: + self.message = f"Lookup {table_name} with invalid keys {diff}" + + class LookupUnsatFailure(Exception): - def __init__(self, table_name: str, inputs: Tuple[int, ...]) -> None: + def __init__(self, table_name: str, inputs: Any) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is unsatisfied on inputs {inputs}" class LookupAmbiguousFailure(Exception): - def __init__( - self, table_name: str, inputs: Tuple[int, ...], matched_rows: Sequence[Tuple[int, ...]] - ) -> None: + def __init__(self, table_name: str, inputs: Any, matched_rows: Sequence[Any]) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is ambiguous on inputs {inputs}, ${len(matched_rows)} matched rows found: {matched_rows}" +class TableRow: + @classmethod + def validate_query(cls, table_name: str, query: Mapping[str, Any]): + names = set([field.name for field in fields(cls)]) + queried = set(query.keys()) + if not queried.issubset(names): + raise WrongQueryKey(table_name, queried - names) + + def match(self, query: Mapping[str, Any]) -> bool: + kv = asdict(self) + return all([int(kv[key]) == int(value) for key, value in query.items()]) + + +@dataclass(frozen=True) +class FixedTableRow(TableRow): + tag: FixedTableTag + value1: FQ + value2: FQ = field(default=FQ(0)) + value3: FQ = field(default=FQ(0)) + + +@dataclass(frozen=True) +class BlockTableRow(TableRow): + tag: BlockContextFieldTag + # meaningful only for HistoryHash, will be zero for other tags + block_number_or_zero: FQ + value: Union[FQ, RLC] + + +@dataclass(frozen=True) +class TxTableRow(TableRow): + tx_id: FQ + tag: TxContextFieldTag + # meaningful only for CallData, will be zero for other tags + call_data_index_or_zero: FQ + value: Union[FQ, RLC] + + +@dataclass(frozen=True) +class BytecodeTableRow(TableRow): + bytecode_hash: FQ + index: FQ + byte: FQ + is_code: FQ + + +@dataclass(frozen=True) +class RWTableRow(TableRow): + rw_counter: FQ + is_write: FQ + # key1 is also the tag + key1: RWTableTag + key2: FQ + key3: FQ + key4: FQ + value: FQ + value_prev: FQ + aux1: FQ + aux2: FQ + + class Tables: """ A collection of lookup tables used in EVM circuit. @@ -253,95 +331,109 @@ class Tables: _: Placeholder = Placeholder() - # Each row in FixedTable contains: - # - tag - # - value1 - # - value2 - # - value3 - fixed_table: Set[Array4] = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) - - # Each row in BlockTable contains: - # - tag - # - block_number_or_zero (meaningful only for HistoryHash, will be zero for other tags) - # - value - block_table: Set[Array3] - - # Each row in TxTable contains: - # - tx_id - # - tag - # - call_data_index_or_zero (meaningful only for CallData, will be zero for other tags) - # - value - tx_table: Set[Array4] - - # Each row in BytecodeTable contains: - # - bytecode_hash - # - index - # - byte - # - is_code - bytecode_table: Set[Array4] - - # Each row in RWTable contains: - # - rw_counter - # - is_write - # - key1 (tag) - # - key2 - # - key3 - # - key4 - # - value - # - value_prev - # - aux1 - # - aux2 - rw_table: Set[Array10] + fixed_table: Set[FixedTableRow] = set( + chain(*[tag.table_assignments() for tag in list(FixedTableTag)]) + ) + block_table: Set[BlockTableRow] + tx_table: Set[TxTableRow] + bytecode_table: Set[BytecodeTableRow] + rw_table: Set[RWTableRow] def __init__( self, - block_table: Set[Array3], - tx_table: Set[Array4], - bytecode_table: Set[Array4], - rw_table: Set[Array10], + block_table: Union[Set[Sequence[IntOrFQ]], Set[BlockTableRow]], + tx_table: Union[Set[Sequence[IntOrFQ]], Set[TxTableRow]], + bytecode_table: Union[Set[Sequence[IntOrFQ]], Set[BytecodeTableRow]], + rw_table: Union[Set[Sequence[IntOrFQ]], Set[RWTableRow]], ) -> None: - self.block_table = block_table - self.tx_table = tx_table - self.bytecode_table = bytecode_table - self.rw_table = rw_table - - def fixed_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("fixed_table", self.fixed_table, inputs) + self.block_table = set( + row if isinstance(row, BlockTableRow) else BlockTableRow(*row) # type: ignore # (BlockTableRow input args) + for row in block_table + ) + self.tx_table = set( + row if isinstance(row, TxTableRow) else TxTableRow(*row) # type: ignore # (TxTableRow input args) + for row in tx_table + ) + self.bytecode_table = set( + row if isinstance(row, BytecodeTableRow) else BytecodeTableRow(*row) # type: ignore # (BytecodeTableRow input args) + for row in bytecode_table + ) + self.rw_table = set( + row if isinstance(row, RWTableRow) else RWTableRow(*row) # type: ignore # (RWTableRow input args) + for row in rw_table + ) + + def fixed_lookup( + self, tag: FixedTableTag, value1: FQ, value2: FQ = None, value3: FQ = None + ) -> FixedTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "tag": tag, + "value1": value1, + "value2": value2, + "value3": value3, + } + return _lookup(FixedTableRow, self.fixed_table, query) + + def block_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ(0)) -> BlockTableRow: + query: Dict[str, Optional[IntOrFQ]] = {"tag": tag, "block_number_or_zero": index} + return _lookup(BlockTableRow, self.block_table, query) + + def tx_lookup(self, tx_id: FQ, field_tag: TxContextFieldTag, index: FQ) -> TxTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "tx_id": tx_id, + "tag": field_tag, + "call_data_index_or_zero": index, + } + return _lookup(TxTableRow, self.tx_table, query) + + def bytecode_lookup(self, bytecode_hash: FQ, index: FQ, is_code: FQ) -> BytecodeTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "bytecode_hash": bytecode_hash, + "index": index, + "is_code": is_code, + } + return _lookup(BytecodeTableRow, self.bytecode_table, query) + + def rw_lookup( + self, rw_counter: FQ, rw: RW, tag: RWTableTag, *other_queries: IntOrFQ + ) -> RWTableRow: + rest_keys = [ + "key2", + "key3", + "key4", + "value", + "value_prev", + "aux1", + "aux2", + ] + query: Dict[str, Optional[IntOrFQ]] = { + "rw_counter": rw_counter, + "is_write": int(rw.value), + "key1": tag, + **dict(zip(rest_keys, other_queries)), + } + return _lookup(RWTableRow, self.rw_table, query) - def block_lookup(self, inputs: Sequence[int]) -> Array3: - assert len(inputs) <= 3 - return _lookup("block_table", self.block_table, inputs) - def tx_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("tx_table", self.tx_table, inputs) +T = TypeVar("T", bound=TableRow) - def bytecode_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("bytecode_table", self.bytecode_table, inputs) - def rw_lookup(self, inputs: Sequence[int]) -> Array10: - assert len(inputs) <= 10 - return _lookup("rw_table", self.rw_table, inputs) +def _lookup( + table_cls: Type[T], + table: Set[T], + query: Mapping[str, Optional[IntOrFQ]], +) -> T: + # cleanup none value + query = {k: v for k, v in query.items() if v is not None} + table_name = table_cls.__name__ + table_cls.validate_query(table_name, query) -def _lookup( - table_name: str, - table: Set[Tuple[int, ...]], - inputs: Sequence[int], -) -> Tuple[int, ...]: - inputs = tuple(inputs) - inputs_len = len(inputs) - matched_rows = [] - - for row in table: - if inputs == row[:inputs_len]: - matched_rows.append(row) + matched_rows = [row for row in table if row.match(query)] if len(matched_rows) == 0: - raise LookupUnsatFailure(table_name, inputs) + raise LookupUnsatFailure(table_name, query) elif len(matched_rows) > 1: - raise LookupAmbiguousFailure(table_name, inputs, matched_rows) + raise LookupAmbiguousFailure(table_name, query, matched_rows) - return [v if isinstance(v, RLC) else FQ(v) for v in matched_rows[0]] + return matched_rows[0] diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 303d163e5..05bd6847f 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,20 +1,25 @@ from __future__ import annotations -from typing import Any, Dict, Iterator, NewType, Optional, Sequence +from typing import Any, Dict, Iterator, NewType, Optional, Sequence, List from functools import reduce from itertools import chain from ..util import ( + FQ, U64, U160, U256, - Array3, - Array4, RLC, keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, ) -from .table import BlockContextFieldTag, TxContextFieldTag +from .table import ( + BlockContextFieldTag, + TxContextFieldTag, + BytecodeTableRow, + TxTableRow, + BlockTableRow, +) from .opcode import get_push_size, Opcode @@ -38,12 +43,12 @@ class Block: def __init__( self, - coinbase: U160 = 0x10, - gas_limit: U64 = int(15e6), - number: U256 = 0, - timestamp: U64 = 0, - difficulty: U256 = 0, - base_fee: U256 = int(1e9), + coinbase: U160 = U160(0x10), + gas_limit: U64 = U64(15_000_000), + number: U256 = U256(0), + timestamp: U64 = U64(0), + difficulty: U256 = U256(0), + base_fee: U256 = U256(1_000_000_000), history_hashes: Sequence[U256] = [], ) -> None: assert len(history_hashes) <= min(256, number) @@ -56,16 +61,20 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, randomness: int) -> Sequence[Array3]: + def table_assignments(self, randomness: FQ) -> List[BlockTableRow]: return [ - (BlockContextFieldTag.Coinbase, 0, self.coinbase), - (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.Number, 0, RLC(self.number, randomness)), - (BlockContextFieldTag.Timestamp, 0, self.timestamp), - (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), - (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), + BlockTableRow(BlockContextFieldTag.Coinbase, FQ(0), FQ(self.coinbase)), + BlockTableRow(BlockContextFieldTag.GasLimit, FQ(0), FQ(self.gas_limit)), + BlockTableRow(BlockContextFieldTag.Number, FQ(0), RLC(self.number, randomness)), + BlockTableRow(BlockContextFieldTag.Timestamp, FQ(0), FQ(self.timestamp)), + BlockTableRow(BlockContextFieldTag.Difficulty, FQ(0), RLC(self.difficulty, randomness)), + BlockTableRow(BlockContextFieldTag.BaseFee, FQ(0), RLC(self.base_fee, randomness)), ] + [ - (BlockContextFieldTag.HistoryHash, self.number - idx - 1, RLC(history_hash, randomness)) + BlockTableRow( + BlockContextFieldTag.HistoryHash, + FQ(self.number - idx - 1), + RLC(history_hash, randomness), + ) for idx, history_hash in enumerate(reversed(self.history_hashes)) ] @@ -83,12 +92,12 @@ class Transaction: def __init__( self, id: int = 1, - nonce: U64 = 0, - gas: U64 = 21000, - gas_price: U256 = int(2e9), - caller_address: U160 = 0, + nonce: U64 = U64(0), + gas: U64 = U64(21_000), + gas_price: U256 = U256(2_000_000_000), + caller_address: U160 = U160(0), callee_address: Optional[U160] = None, - value: U256 = 0, + value: U256 = U256(0), call_data: bytes = bytes(), ) -> None: self.id = id @@ -114,21 +123,46 @@ def call_data_gas_cost(self) -> int: 0, ) - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: FQ) -> Iterator[TxTableRow]: return chain( [ - (self.id, TxContextFieldTag.Nonce, 0, self.nonce), - (self.id, TxContextFieldTag.Gas, 0, self.gas), - (self.id, TxContextFieldTag.GasPrice, 0, RLC(self.gas_price, randomness)), - (self.id, TxContextFieldTag.CallerAddress, 0, self.caller_address), - (self.id, TxContextFieldTag.CalleeAddress, 0, self.callee_address), - (self.id, TxContextFieldTag.IsCreate, 0, self.callee_address is None), - (self.id, TxContextFieldTag.Value, 0, RLC(self.value, randomness)), - (self.id, TxContextFieldTag.CallDataLength, 0, len(self.call_data)), - (self.id, TxContextFieldTag.CallDataGasCost, 0, self.call_data_gas_cost()), + TxTableRow(FQ(self.id), TxContextFieldTag.Nonce, FQ(0), FQ(self.nonce)), + TxTableRow(FQ(self.id), TxContextFieldTag.Gas, FQ(0), FQ(self.gas)), + TxTableRow( + FQ(self.id), + TxContextFieldTag.GasPrice, + FQ(0), + RLC(self.gas_price, randomness), + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.CallerAddress, FQ(0), FQ(self.caller_address) + ), + TxTableRow( + FQ(self.id), + TxContextFieldTag.CalleeAddress, + FQ(0), + FQ(self.callee_address or 0), + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.IsCreate, FQ(0), FQ(self.callee_address is None) + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.Value, FQ(0), RLC(self.value, randomness) + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.CallDataLength, FQ(0), FQ(len(self.call_data)) + ), + TxTableRow( + FQ(self.id), + TxContextFieldTag.CallDataGasCost, + FQ(0), + FQ(self.call_data_gas_cost()), + ), ], map( - lambda item: (self.id, TxContextFieldTag.CallData, item[0], item[1]), + lambda item: TxTableRow( + FQ(self.id), TxContextFieldTag.CallData, FQ(item[0]), FQ(item[1]) + ), enumerate(self.call_data), ), ) @@ -183,10 +217,10 @@ def push(self, value: Any, n_bytes: int = 32) -> Bytecode: return self - def hash(self) -> int: - return int.from_bytes(keccak256(self.code), "big") + def hash(self) -> U256: + return U256(int.from_bytes(keccak256(self.code), "big")) - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: FQ) -> Iterator[BytecodeTableRow]: class BytecodeIterator: idx: int push_data_left: int @@ -214,7 +248,7 @@ def __next__(self): self.idx += 1 - return (self.hash, idx, byte, is_code) + return (self.hash.value, idx, byte, is_code) return BytecodeIterator(RLC(self.hash(), randomness), self.code) @@ -231,9 +265,9 @@ class Account: def __init__( self, - address: U160 = 0, - nonce: U256 = 0, - balance: U256 = 0, + address: U160 = U160(0), + nonce: U256 = U256(0), + balance: U256 = U256(0), code: Optional[Bytecode] = None, storage: Optional[Storage] = None, ) -> None: @@ -241,7 +275,7 @@ def __init__( self.nonce = nonce self.balance = balance self.code = Bytecode() if code is None else code - self.storage = dict() if storage is None else storage + self.storage = Storage(dict()) if storage is None else storage def code_hash(self) -> U256: return self.code.hash() diff --git a/src/zkevm_specs/opcode/comparator.py b/src/zkevm_specs/opcode/comparator.py index 5449f1d93..d39166008 100644 --- a/src/zkevm_specs/opcode/comparator.py +++ b/src/zkevm_specs/opcode/comparator.py @@ -34,7 +34,7 @@ def compare( assert len(result) == 16 # Before we do any comparison, the previous result is "equal" - result = result[:] + [0] + result = list(result[:]) + [Sign(0)] for i in reversed(range(0, 32, 2)): a16 = a8s[i] + 256 * a8s[i + 1] diff --git a/src/zkevm_specs/opcode/lt_gt.py b/src/zkevm_specs/opcode/lt_gt.py index 845c38832..9091ceda6 100644 --- a/src/zkevm_specs/opcode/lt_gt.py +++ b/src/zkevm_specs/opcode/lt_gt.py @@ -1,6 +1,6 @@ from typing import Sequence -from ..encoding import is_circuit_code, U8, U256, u256_to_u8s +from ..encoding import is_circuit_code, U8 def lt_circuit( @@ -34,7 +34,7 @@ def lt_circuit( # lower 16 bytes # a[15:0] + c[15:0] == carry * 256^16 + b[15:0] lhs = 0 - rhs = carry + rhs = int(carry) for i in reversed(range(16)): lhs = lhs * 256 + a8s[i] + c8s[i] rhs = rhs * 256 + b8s[i] diff --git a/src/zkevm_specs/opcode/mload_mstore.py b/src/zkevm_specs/opcode/mload_mstore.py index fc203a2f5..7e0f7e606 100644 --- a/src/zkevm_specs/opcode/mload_mstore.py +++ b/src/zkevm_specs/opcode/mload_mstore.py @@ -15,14 +15,15 @@ def address_low( address: Sequence[U8], ) -> U64: - return sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) + _sum = sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) + return U64(_sum) @is_circuit_code def address_high( address: Sequence[U8], ) -> U256: - return sum(address[NUM_ADDRESS_BYTES_USED:]) + return U256(sum(address[NUM_ADDRESS_BYTES_USED:])) @is_circuit_code @@ -38,7 +39,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code @@ -46,8 +47,8 @@ def div( value: U256, divisor: U64, ) -> Tuple[U256, U256]: - quotient = value // divisor - remainder = value % divisor + quotient = U256(value // divisor) + remainder = U256(value % divisor) return (quotient, remainder) @@ -56,7 +57,7 @@ def lt( lhs: U256, rhs: U256, ) -> U256: - return lhs < rhs + return U256(lhs < rhs) @is_circuit_code @@ -71,7 +72,7 @@ def max( def memory_size( address: U64, ) -> U64: - return (address + 31) // 32 + return U64((address + 31) // 32) @is_circuit_code diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 2bf9d350f..9bbe511f0 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -17,11 +17,11 @@ def rand_fp() -> FQ: def rand_address() -> U160: - return rand_range(2**160) + return U160(rand_range(2**160)) def rand_word() -> U256: - return rand_range(2**256) + return U256(rand_range(2**256)) def rand_bytes(n_bytes: int = 32) -> bytes: diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index e91d14635..9a7d1fe8e 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -7,13 +7,12 @@ def _hash_fq(v: FQ) -> int: return hash(v.n) -FQ.__hash__ = _hash_fq +FQ.__hash__ = _hash_fq # type: ignore IntOrFQ = Union[int, FQ] -def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: int) -> FQ: +def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: FQ) -> FQ: ret = FQ.zero() - factor = FQ(factor) for byte in reversed(le_bytes): assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" ret = ret * factor + byte @@ -25,7 +24,7 @@ class RLC: value: FQ def __init__( - self, int_or_bytes: Union[IntOrFQ, bytes], randomness: int, n_bytes: int = 32 + self, int_or_bytes: Union[IntOrFQ, bytes], randomness: FQ, n_bytes: int = 32 ) -> None: if isinstance(int_or_bytes, int): assert ( @@ -49,7 +48,7 @@ def __init__( self.value = fp_linear_combine(self.le_bytes, randomness) - def __eq__(self, rhs: Union[int, FQ, RLC]): + def __eq__(self, rhs: Union[int, object]): if isinstance(rhs, (int, FQ)): return self.value == rhs if isinstance(rhs, RLC): @@ -63,5 +62,8 @@ def __hash__(self) -> int: def __repr__(self) -> str: return "RLC(%s)" % int.from_bytes(self.le_bytes, "little") + def __int__(self) -> int: + return int.from_bytes(self.le_bytes, "little") + def be_bytes(self) -> bytes: return bytes(reversed(self.le_bytes)) diff --git a/src/zkevm_specs/util/hash.py b/src/zkevm_specs/util/hash.py index 99d750cff..0e7056650 100644 --- a/src/zkevm_specs/util/hash.py +++ b/src/zkevm_specs/util/hash.py @@ -4,12 +4,11 @@ from .typing import U256 -def keccak256(data: Union[str, bytes, bytearray]) -> bytes: - if type(data) == str: - data = bytes.fromhex(data) +def keccak256(_data: Union[str, bytes, bytearray]) -> bytes: + data = bytes.fromhex(_data) if isinstance(_data, str) else _data return keccak.new(digest_bits=256).update(data).digest() -EMPTY_HASH: U256 = int.from_bytes(keccak256(""), "big") +EMPTY_HASH = U256(int.from_bytes(keccak256(""), "big")) EMPTY_CODE_HASH: U256 = EMPTY_HASH -EMPTY_TRIE_HASH: U256 = int.from_bytes(keccak256("80"), "big") +EMPTY_TRIE_HASH = U256(int.from_bytes(keccak256("80"), "big")) diff --git a/src/zkevm_specs/util/typing.py b/src/zkevm_specs/util/typing.py index 0ab5b8234..e8dba8780 100644 --- a/src/zkevm_specs/util/typing.py +++ b/src/zkevm_specs/util/typing.py @@ -1,48 +1,5 @@ -from typing import NewType, Tuple +from typing import NewType U64 = NewType("U64", int) U160 = NewType("U160", int) U256 = NewType("U256", int) - - -Array3 = NewType("Array3", Tuple[int, int, int]) -Array4 = NewType("Array4", Tuple[int, int, int, int]) -Array8 = NewType("Array8", Tuple[int, int, int, int, int, int, int, int]) -Array10 = NewType("Array10", Tuple[int, int, int, int, int, int, int, int, int, int]) -Array32 = NewType( - "Array32", - Tuple[ - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - ], -) diff --git a/tests/evm/test_gasprice.py b/tests/evm/test_gasprice.py index 8b98d0fd9..c2af3cfe5 100644 --- a/tests/evm/test_gasprice.py +++ b/tests/evm/test_gasprice.py @@ -51,7 +51,18 @@ def test_gasprice(gasprice: U256): 0, 0, ), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(gasprice, randomness), 0, 0, 0), + ( + 10, + RW.Write, + RWTableTag.Stack, + 1, + 1023, + 0, + RLC(gasprice, randomness).value, + 0, + 0, + 0, + ), ] ), ) diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index 835e1564c..c5f0148f7 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -92,13 +92,13 @@ def test_bytecode_invalid_hash_data(): # Change the hash on the first position invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Change the hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash @@ -115,13 +115,13 @@ def test_bytecode_invalid_index(): # Start the index at 1 invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[idx] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) invalid_cell = invalid.rows[-1][0] - invalid_cell.value -= 1 + invalid_cell -= 1 invalid.rows[-1] = (invalid_cell, row[1], row[2], row[3]) verify(k, [invalid], randomness, False)