From 45a4f713275306ebb5940dae5a34b72b847df8c3 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Tue, 8 Feb 2022 18:58:03 +0800 Subject: [PATCH 01/15] low hanging fruit --- src/zkevm_specs/encoding/lookup.py | 2 +- src/zkevm_specs/evm/opcode.py | 14 ++++++++------ src/zkevm_specs/evm/table.py | 2 +- src/zkevm_specs/util/arithmetic.py | 2 +- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/zkevm_specs/encoding/lookup.py b/src/zkevm_specs/encoding/lookup.py index 774b892dd..808ab0d58 100644 --- a/src/zkevm_specs/encoding/lookup.py +++ b/src/zkevm_specs/encoding/lookup.py @@ -2,7 +2,7 @@ class LookupTable: - columns: Tuple[str] + columns: Tuple[str, ...] rows: Set[Tuple[int, ...]] def __init__(self, columns: Sequence[str]) -> None: diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 8f58cdc7e..bf7acd647 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Final, Dict, Sequence, Tuple, Union +from typing import Final, Dict, Sequence, Tuple, Union, Optional from ..util.param import * @@ -184,10 +184,12 @@ class OpcodeInfo: max_stack_pointer: int constant_gas_cost: int has_dynamic_gas: bool - pure_memory_expansion_info: Tuple[ - int, # offset stack_pointer_offset - int, # length stack_pointer_offset - int, # constant length + pure_memory_expansion_info: Optional[ + Tuple[ + int, # offset stack_pointer_offset + int, # length stack_pointer_offset + int, # constant length + ] ] def __init__( @@ -196,7 +198,7 @@ def __init__( max_stack_pointer: int, constant_gas_cost: int, has_dynamic_gas: bool = False, - pure_memory_expansion_info: Union[Tuple[int, int, int], None] = None, + pure_memory_expansion_info: Optional[Tuple[int, int, int]] = None, ) -> None: self.min_stack_pointer = min_stack_pointer self.max_stack_pointer = max_stack_pointer diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index c85c4e521..75b1e59fc 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -81,7 +81,7 @@ def table_assignments(self) -> Sequence[Array4]: (self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs() ] else: - ValueError("Unreacheable") + raise ValueError("Unreacheable") def range_table_tag(range: int) -> FixedTableTag: if range == 16: diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index e91d14635..6ea4861a8 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -49,7 +49,7 @@ def __init__( self.value = fp_linear_combine(self.le_bytes, randomness) - def __eq__(self, rhs: Union[int, FQ, RLC]): + def __eq__(self, rhs: Union[int, object]): if isinstance(rhs, (int, FQ)): return self.value == rhs if isinstance(rhs, RLC): From d6a5ea2ac454ab6737d38b302f1e207ab150a594 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Tue, 8 Feb 2022 19:03:37 +0800 Subject: [PATCH 02/15] encoding/utils and util/ --- src/zkevm_specs/encoding/utils.py | 8 ++++---- src/zkevm_specs/util/__init__.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/zkevm_specs/encoding/utils.py b/src/zkevm_specs/encoding/utils.py index 9c0db0cd2..302d76439 100644 --- a/src/zkevm_specs/encoding/utils.py +++ b/src/zkevm_specs/encoding/utils.py @@ -15,19 +15,19 @@ def wrapper(*args, **kargs): def u256_to_u8s(x: U256) -> Tuple[U8, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 8 * i) & 0xFF for i in range(32)) + return tuple(U8((x >> 8 * i) & 0xFF) for i in range(32)) def u256_to_u64s(x: U256) -> Tuple[U64, ...]: assert 0 <= x < 2**256, "expect x is unsigned 256 bits" - return tuple((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF for i in range(4)) + return tuple(U64((x >> 64 * i) & 0xFFFFFFFFFFFFFFFF) for i in range(4)) def u8s_to_u256(xs: Sequence[U8]) -> U256: assert len(xs) == 32 for u8 in xs: assert 0 <= u8 <= 255 - return sum(x * (2 ** (8 * i)) for i, x in enumerate(xs)) + return U256(sum(x * (2 ** (8 * i)) for i, x in enumerate(xs))) # [u8;32]->[u64;4] @@ -37,5 +37,5 @@ def u8s_to_u64s(xs: Sequence[U8]) -> Tuple[U64, ...]: A = [u64_0] * 4 # A = A3A2A1A0 for i in range(4): for j in range(8): - A[i] += U64(xs[j + 8 * i] * (2 ** (8 * j))) + A[i] += xs[j + 8 * i] * (2 ** (8 * j)) return tuple(A) diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 2bf9d350f..9bbe511f0 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -17,11 +17,11 @@ def rand_fp() -> FQ: def rand_address() -> U160: - return rand_range(2**160) + return U160(rand_range(2**160)) def rand_word() -> U256: - return rand_range(2**256) + return U256(rand_range(2**256)) def rand_bytes(n_bytes: int = 32) -> bytes: From 1b5be1be932898e85880613780d1e2e31ad61cd1 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Fri, 11 Feb 2022 18:14:59 +0800 Subject: [PATCH 03/15] add type check command and ci check --- .github/workflows/python-package.yml | 2 ++ Makefile | 3 +++ 2 files changed, 5 insertions(+) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 5a0007af9..22f778391 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -27,5 +27,7 @@ jobs: make install - name: Lint run: make lint + - name: Type check + run: make type - name: Test with pytest run: make test diff --git a/Makefile b/Makefile index d90c6de50..32c551cdd 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ lint: ## Check whether the code is formated correctly black . --check mdformat specs/ --number --check +type: ## Check the typing of the Python code + mypy src + test: ## Run tests pytest From 6cfdad0417f110882ed99398e32675931d20f46a Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Fri, 11 Feb 2022 18:33:09 +0800 Subject: [PATCH 04/15] fix util and lt_gt --- src/zkevm_specs/opcode/lt_gt.py | 4 ++-- src/zkevm_specs/util/arithmetic.py | 6 +++--- src/zkevm_specs/util/hash.py | 9 ++++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/zkevm_specs/opcode/lt_gt.py b/src/zkevm_specs/opcode/lt_gt.py index 845c38832..9091ceda6 100644 --- a/src/zkevm_specs/opcode/lt_gt.py +++ b/src/zkevm_specs/opcode/lt_gt.py @@ -1,6 +1,6 @@ from typing import Sequence -from ..encoding import is_circuit_code, U8, U256, u256_to_u8s +from ..encoding import is_circuit_code, U8 def lt_circuit( @@ -34,7 +34,7 @@ def lt_circuit( # lower 16 bytes # a[15:0] + c[15:0] == carry * 256^16 + b[15:0] lhs = 0 - rhs = carry + rhs = int(carry) for i in reversed(range(16)): lhs = lhs * 256 + a8s[i] + c8s[i] rhs = rhs * 256 + b8s[i] diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index 6ea4861a8..4ae4e6b1c 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -7,13 +7,13 @@ def _hash_fq(v: FQ) -> int: return hash(v.n) -FQ.__hash__ = _hash_fq +FQ.__hash__ = _hash_fq # type: ignore IntOrFQ = Union[int, FQ] -def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: int) -> FQ: +def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], _factor: int) -> FQ: ret = FQ.zero() - factor = FQ(factor) + factor = FQ(_factor) for byte in reversed(le_bytes): assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" ret = ret * factor + byte diff --git a/src/zkevm_specs/util/hash.py b/src/zkevm_specs/util/hash.py index 99d750cff..0e7056650 100644 --- a/src/zkevm_specs/util/hash.py +++ b/src/zkevm_specs/util/hash.py @@ -4,12 +4,11 @@ from .typing import U256 -def keccak256(data: Union[str, bytes, bytearray]) -> bytes: - if type(data) == str: - data = bytes.fromhex(data) +def keccak256(_data: Union[str, bytes, bytearray]) -> bytes: + data = bytes.fromhex(_data) if isinstance(_data, str) else _data return keccak.new(digest_bits=256).update(data).digest() -EMPTY_HASH: U256 = int.from_bytes(keccak256(""), "big") +EMPTY_HASH = U256(int.from_bytes(keccak256(""), "big")) EMPTY_CODE_HASH: U256 = EMPTY_HASH -EMPTY_TRIE_HASH: U256 = int.from_bytes(keccak256("80"), "big") +EMPTY_TRIE_HASH = U256(int.from_bytes(keccak256("80"), "big")) From 3e42f50bfdbe47ed7f14a6ad46cc1040a84e070e Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Fri, 11 Feb 2022 18:42:01 +0800 Subject: [PATCH 05/15] fix bytecode --- src/zkevm_specs/bytecode.py | 2 +- src/zkevm_specs/encoding/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 8b7ef60b4..d198dab3b 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -24,7 +24,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code diff --git a/src/zkevm_specs/encoding/utils.py b/src/zkevm_specs/encoding/utils.py index 302d76439..1b64996ee 100644 --- a/src/zkevm_specs/encoding/utils.py +++ b/src/zkevm_specs/encoding/utils.py @@ -1,8 +1,8 @@ -from typing import Sequence, Tuple, List +from typing import Sequence, Tuple from .typing import U8, U256, U64 -def is_circuit_code(func) -> object: +def is_circuit_code(func): """ A no-op decorator just to mark the function """ From 87fd9643ff15d7f37a1ba0fe50bb50b31d480a14 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Fri, 11 Feb 2022 18:47:25 +0800 Subject: [PATCH 06/15] comparator and mload_mstore --- src/zkevm_specs/opcode/comparator.py | 2 +- src/zkevm_specs/opcode/mload_mstore.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/zkevm_specs/opcode/comparator.py b/src/zkevm_specs/opcode/comparator.py index 5449f1d93..d39166008 100644 --- a/src/zkevm_specs/opcode/comparator.py +++ b/src/zkevm_specs/opcode/comparator.py @@ -34,7 +34,7 @@ def compare( assert len(result) == 16 # Before we do any comparison, the previous result is "equal" - result = result[:] + [0] + result = list(result[:]) + [Sign(0)] for i in reversed(range(0, 32, 2)): a16 = a8s[i] + 256 * a8s[i + 1] diff --git a/src/zkevm_specs/opcode/mload_mstore.py b/src/zkevm_specs/opcode/mload_mstore.py index fc203a2f5..5bf70d977 100644 --- a/src/zkevm_specs/opcode/mload_mstore.py +++ b/src/zkevm_specs/opcode/mload_mstore.py @@ -15,14 +15,14 @@ def address_low( address: Sequence[U8], ) -> U64: - return sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) - + _sum = sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) + return U64(_sum) @is_circuit_code def address_high( address: Sequence[U8], ) -> U256: - return sum(address[NUM_ADDRESS_BYTES_USED:]) + return U256(sum(address[NUM_ADDRESS_BYTES_USED:])) @is_circuit_code @@ -38,7 +38,7 @@ def select( when_true: U256, when_false: U256, ) -> U256: - return selector * when_true + (1 - selector) * when_false + return U256(selector * when_true + (1 - selector) * when_false) @is_circuit_code @@ -46,8 +46,8 @@ def div( value: U256, divisor: U64, ) -> Tuple[U256, U256]: - quotient = value // divisor - remainder = value % divisor + quotient = U256(value // divisor) + remainder = U256(value % divisor) return (quotient, remainder) @@ -56,7 +56,7 @@ def lt( lhs: U256, rhs: U256, ) -> U256: - return lhs < rhs + return U256(lhs < rhs) @is_circuit_code @@ -71,7 +71,7 @@ def max( def memory_size( address: U64, ) -> U64: - return (address + 31) // 32 + return U64((address + 31) // 32) @is_circuit_code From a749160bb465630f874226b69e17a1532ba0c5d3 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Fri, 11 Feb 2022 18:54:23 +0800 Subject: [PATCH 07/15] add mypy version --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index e77608300..81a802224 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,3 +31,4 @@ test = lint = black >= 22.1.0 mdformat >= 0.7.13 + mypy >= 0.931 From 0e56c73caec18267b95ff74b48e2f619a908fd0e Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Mon, 14 Feb 2022 17:44:56 +0800 Subject: [PATCH 08/15] Fix table.py Replace ArrayX with dataclass --- src/zkevm_specs/evm/instruction.py | 16 +- src/zkevm_specs/evm/table.py | 248 ++++++++++++++++++----------- src/zkevm_specs/evm/typing.py | 16 +- src/zkevm_specs/util/typing.py | 45 +----- 4 files changed, 171 insertions(+), 154 deletions(-) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 6cfd3c63f..fb3ccee2c 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -3,8 +3,6 @@ from typing import Optional, Sequence, Tuple, Union, Mapping from ..util import ( - Array4, - Array10, FQ, IntOrFQ, RLC, @@ -27,6 +25,8 @@ TxContextFieldTag, RW, RWTableTag, + RWTableRow, + FixedTableRow, ) @@ -325,8 +325,8 @@ def range_check(self, value: FQ, n_bytes: int) -> bytes: except OverflowError: raise ConstraintUnsatFailure(f"Value {value} has too many bytes to fit {n_bytes} bytes") - def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> Array4: - return self.tables.fixed_lookup([tag] + inputs) + def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> FixedTableRow: + return self.tables.fixed_lookup(tag, *inputs) def block_context_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ.zero()) -> FQ: return self.tables.block_lookup([tag, index])[2] @@ -361,9 +361,7 @@ def opcode_lookup_at(self, index: FQ, is_code: bool) -> FQ: else: return self.bytecode_lookup(self.curr.code_source, index, is_code) - def rw_lookup( - self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None - ) -> Array10: + def rw_lookup(self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None) -> RWTableRow: if rw_counter is None: rw_counter = self.curr.rw_counter + self.rw_counter_offset self.rw_counter_offset += 1 @@ -374,7 +372,7 @@ def state_write_only_persistent( tag: RWTableTag, inputs: Sequence[int], is_persistent: bool, - ) -> Array10: + ) -> RWTableRow: assert tag.write_only_persistent() if is_persistent: @@ -389,7 +387,7 @@ def state_write_with_reversion( is_persistent: bool, rw_counter_end_of_reversion: int, state_write_counter: Optional[int] = None, - ) -> Array10: + ) -> RWTableRow: assert tag.write_with_reversion() row = self.rw_lookup(RW.Write, tag, inputs) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 75b1e59fc..efbf4d740 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,9 +1,11 @@ from __future__ import annotations -from typing import Sequence, Set, Tuple -from enum import IntEnum, auto +from typing import Mapping, Sequence, Set, List, TypeVar, Any, Type, Optional, Dict +from enum import IntEnum, auto, Enum from itertools import chain, product +from dataclasses import dataclass, field, asdict, fields -from ..util import FQ, RLC, Array3, Array4, Array10 + +from ..util import FQ, IntOrFQ from .execution_state import ExecutionState from .opcode import ( invalid_opcodes, @@ -40,45 +42,55 @@ class FixedTableTag(IntEnum): StackOverflow = auto() # opcode, stack_pointer, 0 StackUnderflow = auto() # opcode, stack_pointer, 0 - def table_assignments(self) -> Sequence[Array4]: + def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range16: - return [(self, i, 0, 0) for i in range(16)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(16)] elif self == FixedTableTag.Range32: - return [(self, i, 0, 0) for i in range(32)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(32)] elif self == FixedTableTag.Range64: - return [(self, i, 0, 0) for i in range(64)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(64)] elif self == FixedTableTag.Range256: - return [(self, i, 0, 0) for i in range(256)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(256)] elif self == FixedTableTag.Range512: - return [(self, i, 0, 0) for i in range(512)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(512)] elif self == FixedTableTag.Range1024: - return [(self, i, 0, 0) for i in range(1024)] + return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(1024)] elif self == FixedTableTag.SignByte: - return [(self, i, (i >> 7) * 0xFF, 0) for i in range(256)] + return [FixedTableRow(FQ(self.value), FQ(i), FQ((i >> 7) * 0xFF)) for i in range(256)] elif self == FixedTableTag.BitwiseAnd: - return [(self, lhs, rhs, lhs & rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs & rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseOr: - return [(self, lhs, rhs, lhs | rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs | rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.BitwiseXor: - return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] + return [ + FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs ^ rhs)) + for lhs, rhs in product(range(256), range(256)) + ] elif self == FixedTableTag.ResponsibleOpcode: return [ - (self, execution_state, opcode, 0) + FixedTableRow(FQ(self.value), FQ(execution_state), FQ(opcode)) for execution_state in list(ExecutionState) for opcode in execution_state.responsible_opcode() ] elif self == FixedTableTag.InvalidOpcode: - return [(self, opcode, 0, 0) for opcode in invalid_opcodes()] + return [FixedTableRow(FQ(self.value), FQ(opcode)) for opcode in invalid_opcodes()] elif self == FixedTableTag.StateWriteOpcode: - return [(self, opcode, 0, 0) for opcode in state_write_opcodes()] + return [FixedTableRow(FQ(self.value), FQ(opcode)) for opcode in state_write_opcodes()] elif self == FixedTableTag.StackOverflow: return [ - (self, opcode, stack_pointer, 0) + FixedTableRow(FQ(self.value), FQ(opcode), FQ(stack_pointer)) for opcode, stack_pointer in stack_underflow_pairs() ] elif self == FixedTableTag.StackUnderflow: return [ - (self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs() + FixedTableRow(FQ(self.value), FQ(opcode), FQ(stack_pointer)) + for opcode, stack_pointer in stack_overflow_pairs() ] else: raise ValueError("Unreacheable") @@ -139,7 +151,7 @@ class TxContextFieldTag(IntEnum): CallData = auto() -class RW: +class RW(Enum): Read = False Write = True @@ -232,6 +244,11 @@ class CallContextFieldTag(IntEnum): StateWriteCounter = auto() +class WrongQueryKey(Exception): + def __init__(self, table_name: str, diff: Set[str]) -> None: + self.message = f"Lookup {table_name} with invalid keys {diff}" + + class LookupUnsatFailure(Exception): def __init__(self, table_name: str, inputs: Tuple[int, ...]) -> None: self.inputs = inputs @@ -239,13 +256,72 @@ def __init__(self, table_name: str, inputs: Tuple[int, ...]) -> None: class LookupAmbiguousFailure(Exception): - def __init__( - self, table_name: str, inputs: Tuple[int, ...], matched_rows: Sequence[Tuple[int, ...]] - ) -> None: + def __init__(self, table_name: str, inputs: Tuple[int, ...], matched_rows: Sequence[Any]) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is ambiguous on inputs {inputs}, ${len(matched_rows)} matched rows found: {matched_rows}" +class TableRow: + @classmethod + def validate_query(cls, table_name: str, query: Mapping[str, Any]): + names = set([field.name for field in fields(cls)]) + queried = set(query.keys()) + if not queried.issubset(names): + raise WrongQueryKey(table_name, queried - names) + + def match(self, query: Mapping[str, Any]) -> bool: + kv = asdict(self) + return all([int(kv[key]) == int(value) for key, value in query.items()]) + + +@dataclass(frozen=True) +class FixedTableRow(TableRow): + tag: FixedTableTag + value1: FQ + value2: FQ = field(default=FQ(0)) + value3: FQ = field(default=FQ(0)) + + +@dataclass(frozen=True) +class BlockTableRow(TableRow): + tag: BlockContextFieldTag + # meaningful only for HistoryHash, will be zero for other tags + block_number_or_zero: FQ + value: FQ + + +@dataclass(frozen=True) +class TxTableRow(TableRow): + tx_id: FQ + tag: FQ + # meaningful only for CallData, will be zero for other tags + call_data_index_or_zero: FQ + value: FQ + + +@dataclass(frozen=True) +class BytecodeTableRow(TableRow): + bytecode_hash: FQ + index: FQ + byte: FQ + is_code: FQ + + +@dataclass(frozen=True) +class RWTableRow(TableRow): + rw_counter: FQ + is_write: FQ + # key1 is also the tag + key1: FQ + key2: FQ + key3: FQ + key4: FQ + value: FQ + value_prev: FQ + aux1: FQ + aux2: FQ + + class Tables: """ A collection of lookup tables used in EVM circuit. @@ -253,95 +329,77 @@ class Tables: _: Placeholder = Placeholder() - # Each row in FixedTable contains: - # - tag - # - value1 - # - value2 - # - value3 - fixed_table: Set[Array4] = set(chain(*[tag.table_assignments() for tag in list(FixedTableTag)])) - - # Each row in BlockTable contains: - # - tag - # - block_number_or_zero (meaningful only for HistoryHash, will be zero for other tags) - # - value - block_table: Set[Array3] - - # Each row in TxTable contains: - # - tx_id - # - tag - # - call_data_index_or_zero (meaningful only for CallData, will be zero for other tags) - # - value - tx_table: Set[Array4] - - # Each row in BytecodeTable contains: - # - bytecode_hash - # - index - # - byte - # - is_code - bytecode_table: Set[Array4] - - # Each row in RWTable contains: - # - rw_counter - # - is_write - # - key1 (tag) - # - key2 - # - key3 - # - key4 - # - value - # - value_prev - # - aux1 - # - aux2 - rw_table: Set[Array10] + fixed_table: Set[FixedTableRow] = set( + chain(*[tag.table_assignments() for tag in list(FixedTableTag)]) + ) + block_table: Set[BlockTableRow] + tx_table: Set[TxTableRow] + bytecode_table: Set[BytecodeTableRow] + rw_table: Set[RWTableRow] def __init__( self, - block_table: Set[Array3], - tx_table: Set[Array4], - bytecode_table: Set[Array4], - rw_table: Set[Array10], + block_table: Set[BlockTableRow], + tx_table: Set[TxTableRow], + bytecode_table: Set[BytecodeTableRow], + rw_table: Set[RWTableRow], ) -> None: self.block_table = block_table self.tx_table = tx_table self.bytecode_table = bytecode_table self.rw_table = rw_table - def fixed_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("fixed_table", self.fixed_table, inputs) + def fixed_lookup( + self, tag: FixedTableTag, value1: FQ, value2: FQ = None, value3: FQ = None + ) -> FixedTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "tag": tag, + "value1": value1, + "value2": value2, + "value3": value3, + } + return _lookup(FixedTableRow, self.fixed_table, query) + + def block_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ(0)) -> BlockTableRow: + query: Dict[str, Optional[IntOrFQ]] = {"tag": tag, "block_number_or_zero": index} + return _lookup(BlockTableRow, self.block_table, query) - def block_lookup(self, inputs: Sequence[int]) -> Array3: - assert len(inputs) <= 3 - return _lookup("block_table", self.block_table, inputs) + def tx_lookup(self, tx_id: int, field_tag: TxContextFieldTag, index: FQ) -> TxTableRow: + query: Dict[str, Optional[IntOrFQ]] = {"tx_id": tx_id, "tag": field_tag, "call_data_index_or_zero": index} + return _lookup(TxTableRow, self.tx_table, query) - def tx_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("tx_table", self.tx_table, inputs) + def bytecode_lookup(self, bytecode_hash: FQ, index: FQ, is_code: FQ) -> BytecodeTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "bytecode_hash": bytecode_hash, + "index": index, + "is_code": is_code, + } + return _lookup(BytecodeTableRow, self.bytecode_table, query) - def bytecode_lookup(self, inputs: Sequence[int]) -> Array4: - assert len(inputs) <= 4 - return _lookup("bytecode_table", self.bytecode_table, inputs) + def rw_lookup(self, rw_counter: FQ, rw: RW, tag: RWTableTag, **other_queries: IntOrFQ) -> RWTableRow: + query: Dict[str, Optional[IntOrFQ]] = {"rw_counter": rw_counter, "rw": int(rw.value), "tag": tag, **other_queries} + return _lookup(RWTableRow, self.rw_table, query) - def rw_lookup(self, inputs: Sequence[int]) -> Array10: - assert len(inputs) <= 10 - return _lookup("rw_table", self.rw_table, inputs) + +T = TypeVar("T", bound=TableRow) def _lookup( - table_name: str, - table: Set[Tuple[int, ...]], - inputs: Sequence[int], -) -> Tuple[int, ...]: - inputs = tuple(inputs) - inputs_len = len(inputs) - matched_rows = [] - - for row in table: - if inputs == row[:inputs_len]: - matched_rows.append(row) + table_cls: Type[T], + table: Set[T], + query: Mapping[str, Optional[IntOrFQ]], +) -> T: + # cleanup none value + query = {k: v for k, v in query.items() if v is not None} + + table_name = table_cls.__name__ + table_cls.validate_query(table_name, query) + + matched_rows = [row for row in table if row.match(query)] if len(matched_rows) == 0: - raise LookupUnsatFailure(table_name, inputs) + raise LookupUnsatFailure(table_name, query) elif len(matched_rows) > 1: - raise LookupAmbiguousFailure(table_name, inputs, matched_rows) + raise LookupAmbiguousFailure(table_name, query, matched_rows) - return [v if isinstance(v, RLC) else FQ(v) for v in matched_rows[0]] + return matched_rows[0] diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 303d163e5..2906ee5b3 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -7,14 +7,18 @@ U64, U160, U256, - Array3, - Array4, RLC, keccak256, GAS_COST_TX_CALL_DATA_PER_NON_ZERO_BYTE, GAS_COST_TX_CALL_DATA_PER_ZERO_BYTE, ) -from .table import BlockContextFieldTag, TxContextFieldTag +from .table import ( + BlockContextFieldTag, + TxContextFieldTag, + BytecodeTableRow, + TxTableRow, + BlockTableRow, +) from .opcode import get_push_size, Opcode @@ -56,7 +60,7 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, randomness: int) -> Sequence[Array3]: + def table_assignments(self, randomness: int) -> Sequence[BlockTableRow]: return [ (BlockContextFieldTag.Coinbase, 0, self.coinbase), (BlockContextFieldTag.GasLimit, 0, self.gas_limit), @@ -114,7 +118,7 @@ def call_data_gas_cost(self) -> int: 0, ) - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: int) -> Iterator[TxTableRow]: return chain( [ (self.id, TxContextFieldTag.Nonce, 0, self.nonce), @@ -186,7 +190,7 @@ def push(self, value: Any, n_bytes: int = 32) -> Bytecode: def hash(self) -> int: return int.from_bytes(keccak256(self.code), "big") - def table_assignments(self, randomness: int) -> Iterator[Array4]: + def table_assignments(self, randomness: int) -> Iterator[BytecodeTableRow]: class BytecodeIterator: idx: int push_data_left: int diff --git a/src/zkevm_specs/util/typing.py b/src/zkevm_specs/util/typing.py index 0ab5b8234..e8dba8780 100644 --- a/src/zkevm_specs/util/typing.py +++ b/src/zkevm_specs/util/typing.py @@ -1,48 +1,5 @@ -from typing import NewType, Tuple +from typing import NewType U64 = NewType("U64", int) U160 = NewType("U160", int) U256 = NewType("U256", int) - - -Array3 = NewType("Array3", Tuple[int, int, int]) -Array4 = NewType("Array4", Tuple[int, int, int, int]) -Array8 = NewType("Array8", Tuple[int, int, int, int, int, int, int, int]) -Array10 = NewType("Array10", Tuple[int, int, int, int, int, int, int, int, int, int]) -Array32 = NewType( - "Array32", - Tuple[ - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - int, - ], -) From ba1874b1005ae4879985ac3591a13a9e40deec59 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Tue, 15 Feb 2022 17:10:56 +0800 Subject: [PATCH 09/15] fix step --- src/zkevm_specs/evm/step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index 28bcf2271..6914eb000 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -50,7 +50,7 @@ def __init__( call_id: int = 0, is_root: bool = False, is_create: bool = False, - code_source: int = 0, + code_source: RLC = RLC(0, 0), program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, From 387a3347295e2c12258eb29b89d743becef0ff93 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Tue, 15 Feb 2022 18:22:51 +0800 Subject: [PATCH 10/15] fix typing.py --- src/zkevm_specs/evm/table.py | 34 ++++---- src/zkevm_specs/evm/typing.py | 103 ++++++++++++++++--------- src/zkevm_specs/opcode/mload_mstore.py | 1 + src/zkevm_specs/util/arithmetic.py | 2 +- 4 files changed, 86 insertions(+), 54 deletions(-) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index efbf4d740..124671a6a 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -44,52 +44,52 @@ class FixedTableTag(IntEnum): def table_assignments(self) -> List[FixedTableRow]: if self == FixedTableTag.Range16: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(16)] + return [FixedTableRow(self, FQ(i)) for i in range(16)] elif self == FixedTableTag.Range32: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(32)] + return [FixedTableRow(self, FQ(i)) for i in range(32)] elif self == FixedTableTag.Range64: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(64)] + return [FixedTableRow(self, FQ(i)) for i in range(64)] elif self == FixedTableTag.Range256: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(256)] + return [FixedTableRow(self, FQ(i)) for i in range(256)] elif self == FixedTableTag.Range512: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(512)] + return [FixedTableRow(self, FQ(i)) for i in range(512)] elif self == FixedTableTag.Range1024: - return [FixedTableRow(FQ(self.value), FQ(i)) for i in range(1024)] + return [FixedTableRow(self, FQ(i)) for i in range(1024)] elif self == FixedTableTag.SignByte: - return [FixedTableRow(FQ(self.value), FQ(i), FQ((i >> 7) * 0xFF)) for i in range(256)] + return [FixedTableRow(self, FQ(i), FQ((i >> 7) * 0xFF)) for i in range(256)] elif self == FixedTableTag.BitwiseAnd: return [ - FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs & rhs)) + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs & rhs)) for lhs, rhs in product(range(256), range(256)) ] elif self == FixedTableTag.BitwiseOr: return [ - FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs | rhs)) + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs | rhs)) for lhs, rhs in product(range(256), range(256)) ] elif self == FixedTableTag.BitwiseXor: return [ - FixedTableRow(FQ(self.value), FQ(lhs), FQ(rhs), FQ(lhs ^ rhs)) + FixedTableRow(self, FQ(lhs), FQ(rhs), FQ(lhs ^ rhs)) for lhs, rhs in product(range(256), range(256)) ] elif self == FixedTableTag.ResponsibleOpcode: return [ - FixedTableRow(FQ(self.value), FQ(execution_state), FQ(opcode)) + FixedTableRow(self, FQ(execution_state), FQ(opcode)) for execution_state in list(ExecutionState) for opcode in execution_state.responsible_opcode() ] elif self == FixedTableTag.InvalidOpcode: - return [FixedTableRow(FQ(self.value), FQ(opcode)) for opcode in invalid_opcodes()] + return [FixedTableRow(self, FQ(opcode)) for opcode in invalid_opcodes()] elif self == FixedTableTag.StateWriteOpcode: - return [FixedTableRow(FQ(self.value), FQ(opcode)) for opcode in state_write_opcodes()] + return [FixedTableRow(self, FQ(opcode)) for opcode in state_write_opcodes()] elif self == FixedTableTag.StackOverflow: return [ - FixedTableRow(FQ(self.value), FQ(opcode), FQ(stack_pointer)) + FixedTableRow(self, FQ(opcode), FQ(stack_pointer)) for opcode, stack_pointer in stack_underflow_pairs() ] elif self == FixedTableTag.StackUnderflow: return [ - FixedTableRow(FQ(self.value), FQ(opcode), FQ(stack_pointer)) + FixedTableRow(self, FQ(opcode), FQ(stack_pointer)) for opcode, stack_pointer in stack_overflow_pairs() ] else: @@ -293,7 +293,7 @@ class BlockTableRow(TableRow): @dataclass(frozen=True) class TxTableRow(TableRow): tx_id: FQ - tag: FQ + tag: TxContextFieldTag # meaningful only for CallData, will be zero for other tags call_data_index_or_zero: FQ value: FQ @@ -312,7 +312,7 @@ class RWTableRow(TableRow): rw_counter: FQ is_write: FQ # key1 is also the tag - key1: FQ + key1: RWTableTag key2: FQ key3: FQ key4: FQ diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 2906ee5b3..528497281 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, Dict, Iterator, NewType, Optional, Sequence +from typing import Any, Dict, Iterator, NewType, Optional, Sequence, List from functools import reduce from itertools import chain from ..util import ( + FQ, U64, U160, U256, @@ -42,12 +43,12 @@ class Block: def __init__( self, - coinbase: U160 = 0x10, - gas_limit: U64 = int(15e6), - number: U256 = 0, - timestamp: U64 = 0, - difficulty: U256 = 0, - base_fee: U256 = int(1e9), + coinbase: U160 = U160(0x10), + gas_limit: U64 = U64(15_000_000), + number: U256 = U256(0), + timestamp: U64 = U64(0), + difficulty: U256 = U256(0), + base_fee: U256 = U256(1_000_000_000), history_hashes: Sequence[U256] = [], ) -> None: assert len(history_hashes) <= min(256, number) @@ -60,16 +61,24 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, randomness: int) -> Sequence[BlockTableRow]: + def table_assignments(self, randomness: int) -> List[BlockTableRow]: return [ - (BlockContextFieldTag.Coinbase, 0, self.coinbase), - (BlockContextFieldTag.GasLimit, 0, self.gas_limit), - (BlockContextFieldTag.Number, 0, RLC(self.number, randomness)), - (BlockContextFieldTag.Timestamp, 0, self.timestamp), - (BlockContextFieldTag.Difficulty, 0, RLC(self.difficulty, randomness)), - (BlockContextFieldTag.BaseFee, 0, RLC(self.base_fee, randomness)), + BlockTableRow(BlockContextFieldTag.Coinbase, FQ(0), FQ(self.coinbase)), + BlockTableRow(BlockContextFieldTag.GasLimit, FQ(0), FQ(self.gas_limit)), + BlockTableRow(BlockContextFieldTag.Number, FQ(0), RLC(self.number, randomness).value), + BlockTableRow(BlockContextFieldTag.Timestamp, FQ(0), FQ(self.timestamp)), + BlockTableRow( + BlockContextFieldTag.Difficulty, FQ(0), RLC(self.difficulty, randomness).value + ), + BlockTableRow( + BlockContextFieldTag.BaseFee, FQ(0), RLC(self.base_fee, randomness).value + ), ] + [ - (BlockContextFieldTag.HistoryHash, self.number - idx - 1, RLC(history_hash, randomness)) + BlockTableRow( + BlockContextFieldTag.HistoryHash, + FQ(self.number - idx - 1), + RLC(history_hash, randomness).value, + ) for idx, history_hash in enumerate(reversed(self.history_hashes)) ] @@ -87,12 +96,12 @@ class Transaction: def __init__( self, id: int = 1, - nonce: U64 = 0, - gas: U64 = 21000, - gas_price: U256 = int(2e9), - caller_address: U160 = 0, + nonce: U64 = U64(0), + gas: U64 = U64(21_000), + gas_price: U256 = U256(2_000_000_000), + caller_address: U160 = U160(0), callee_address: Optional[U160] = None, - value: U256 = 0, + value: U256 = U256(0), call_data: bytes = bytes(), ) -> None: self.id = id @@ -121,18 +130,40 @@ def call_data_gas_cost(self) -> int: def table_assignments(self, randomness: int) -> Iterator[TxTableRow]: return chain( [ - (self.id, TxContextFieldTag.Nonce, 0, self.nonce), - (self.id, TxContextFieldTag.Gas, 0, self.gas), - (self.id, TxContextFieldTag.GasPrice, 0, RLC(self.gas_price, randomness)), - (self.id, TxContextFieldTag.CallerAddress, 0, self.caller_address), - (self.id, TxContextFieldTag.CalleeAddress, 0, self.callee_address), - (self.id, TxContextFieldTag.IsCreate, 0, self.callee_address is None), - (self.id, TxContextFieldTag.Value, 0, RLC(self.value, randomness)), - (self.id, TxContextFieldTag.CallDataLength, 0, len(self.call_data)), - (self.id, TxContextFieldTag.CallDataGasCost, 0, self.call_data_gas_cost()), + TxTableRow(FQ(self.id), TxContextFieldTag.Nonce, FQ(0), FQ(self.nonce)), + TxTableRow(FQ(self.id), TxContextFieldTag.Gas, FQ(0), FQ(self.gas)), + TxTableRow( + FQ(self.id), + TxContextFieldTag.GasPrice, + FQ(0), + RLC(self.gas_price, randomness).value, + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.CallerAddress, FQ(0), FQ(self.caller_address) + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.CalleeAddress, FQ(0), FQ(self.callee_address) + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.IsCreate, FQ(0), FQ(self.callee_address is None) + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.Value, FQ(0), RLC(self.value, randomness).value + ), + TxTableRow( + FQ(self.id), TxContextFieldTag.CallDataLength, FQ(0), FQ(len(self.call_data)) + ), + TxTableRow( + FQ(self.id), + TxContextFieldTag.CallDataGasCost, + FQ(0), + FQ(self.call_data_gas_cost()), + ), ], map( - lambda item: (self.id, TxContextFieldTag.CallData, item[0], item[1]), + lambda item: TxTableRow( + FQ(self.id), TxContextFieldTag.CallData, FQ(item[0]), FQ(item[1]) + ), enumerate(self.call_data), ), ) @@ -187,8 +218,8 @@ def push(self, value: Any, n_bytes: int = 32) -> Bytecode: return self - def hash(self) -> int: - return int.from_bytes(keccak256(self.code), "big") + def hash(self) -> U256: + return U256(int.from_bytes(keccak256(self.code), "big")) def table_assignments(self, randomness: int) -> Iterator[BytecodeTableRow]: class BytecodeIterator: @@ -235,9 +266,9 @@ class Account: def __init__( self, - address: U160 = 0, - nonce: U256 = 0, - balance: U256 = 0, + address: U160 = U160(0), + nonce: U256 = U256(0), + balance: U256 = U256(0), code: Optional[Bytecode] = None, storage: Optional[Storage] = None, ) -> None: @@ -245,7 +276,7 @@ def __init__( self.nonce = nonce self.balance = balance self.code = Bytecode() if code is None else code - self.storage = dict() if storage is None else storage + self.storage = Storage(dict()) if storage is None else storage def code_hash(self) -> U256: return self.code.hash() diff --git a/src/zkevm_specs/opcode/mload_mstore.py b/src/zkevm_specs/opcode/mload_mstore.py index 5bf70d977..7e0f7e606 100644 --- a/src/zkevm_specs/opcode/mload_mstore.py +++ b/src/zkevm_specs/opcode/mload_mstore.py @@ -18,6 +18,7 @@ def address_low( _sum = sum(x * (2 ** (8 * i)) for i, x in enumerate(address[:NUM_ADDRESS_BYTES_USED])) return U64(_sum) + @is_circuit_code def address_high( address: Sequence[U8], diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index 4ae4e6b1c..f9be2314a 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -7,7 +7,7 @@ def _hash_fq(v: FQ) -> int: return hash(v.n) -FQ.__hash__ = _hash_fq # type: ignore +FQ.__hash__ = _hash_fq # type: ignore IntOrFQ = Union[int, FQ] From 4f0ade495ae2dffe267d09f70aaa9e4591ead4e7 Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Tue, 15 Feb 2022 18:57:07 +0800 Subject: [PATCH 11/15] instruction --- src/zkevm_specs/evm/instruction.py | 191 +++++++++++++++-------------- src/zkevm_specs/evm/table.py | 32 ++++- src/zkevm_specs/util/arithmetic.py | 2 +- 3 files changed, 124 insertions(+), 101 deletions(-) diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index fb3ccee2c..2185c18bc 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -27,6 +27,9 @@ RWTableTag, RWTableRow, FixedTableRow, + BlockTableRow, + TxTableRow, + BytecodeTableRow, ) @@ -49,13 +52,16 @@ def __init__(self, kind: TransitionKind, value: Optional[int] = None) -> None: self.kind = kind self.value = value - def same() -> Transition: + @classmethod + def same(cls) -> Transition: return Transition(TransitionKind.Same) - def delta(delta: int): + @classmethod + def delta(cls, delta: int): return Transition(TransitionKind.Delta, delta) - def to(to: int): + @classmethod + def to(cls, to: int): return Transition(TransitionKind.To, to) @@ -107,7 +113,7 @@ def constrain_bool(self, num: FQ): def constrain_gas_left_not_underflow(self, gas_left: FQ): self.range_check(gas_left, N_BYTES_GAS) - def constrain_step_state_transition(self, **kwargs: Mapping[str, Transition]): + def constrain_step_state_transition(self, **kwargs: Transition): keys = set( [ "rw_counter", @@ -198,7 +204,7 @@ def step_state_transition_in_same_context( ) def sum(self, values: Sequence[FQ]) -> FQ: - return sum(values) + return FQ(sum(values)) def is_zero(self, value: Union[FQ, RLC]) -> bool: return value == 0 @@ -223,8 +229,8 @@ def pair_select(self, value: FQ, lhs: FQ, rhs: FQ) -> Tuple[bool, bool]: def constant_divmod( self, numerator: IntOrFQ, denominator: IntOrFQ, n_bytes: int ) -> Tuple[FQ, FQ]: - quotient, remainder = divmod(FQ(numerator).n, FQ(denominator).n) - quotient, remainder = FQ(quotient), FQ(remainder) + _quotient, _remainder = divmod(FQ(numerator).n, FQ(denominator).n) + quotient, remainder = FQ(_quotient), FQ(_remainder) self.range_check(quotient, n_bytes) return quotient, remainder @@ -246,12 +252,11 @@ def add_words(self, addends: Sequence[RLC]) -> Tuple[RLC, FQ]: addends_lo, addends_hi = list(zip(*map(self.word_to_lo_hi, addends))) carry_lo, sum_lo = divmod(self.sum(addends_lo).n, 1 << 128) - carry_hi, sum_hi = divmod((self.sum(addends_hi) + carry_lo).n, 1 << 128) + carry_hi, sum_hi = divmod(self.sum(addends_hi).n + carry_lo, 1 << 128) sum_bytes = sum_lo.to_bytes(16, "little") + sum_hi.to_bytes(16, "little") - carry_hi = FQ(carry_hi) - return RLC(sum_bytes, self.randomness), carry_hi + return RLC(sum_bytes, self.randomness, 32), FQ(carry_hi) def sub_word(self, minuend: RLC, subtrahend: RLC) -> Tuple[RLC, bool]: minuend_lo, minuend_hi = self.word_to_lo_hi(minuend) @@ -273,19 +278,12 @@ def mul_word_by_u64(self, multiplicand: RLC, multiplier: FQ) -> Tuple[RLC, FQ]: quotient_hi, product_hi = divmod((multiplicand_hi * multiplier + quotient_lo).n, 1 << 128) product_bytes = product_lo.to_bytes(16, "little") + product_hi.to_bytes(16, "little") - quotient_hi = FQ(quotient_hi) - return RLC(product_bytes, self.randomness), quotient_hi + return RLC(product_bytes, self.randomness), FQ(quotient_hi) def rlc_to_le_bytes(self, rlc: RLC) -> bytes: return rlc.le_bytes - def rlc_to_fq_unchecked(self, rlc: RLC, n_bytes: int) -> FQ: - rlc_le_bytes = self.rlc_to_le_bytes(rlc) - return self.bytes_to_fq(rlc_le_bytes[:n_bytes]), self.is_zero( - self.sum(rlc_le_bytes[n_bytes:]) - ) - def rlc_to_fq_exact(self, rlc: RLC, n_bytes: int) -> FQ: rlc_le_bytes = self.rlc_to_le_bytes(rlc) @@ -311,7 +309,7 @@ def bytes_to_fq(self, value: bytes) -> FQ: return FQ(int.from_bytes(value, "little")) def range_lookup(self, value: FQ, range: int): - self.tables.fixed_lookup([FixedTableTag.range_table_tag(range), value, 0, 0]) + self.tables.fixed_lookup(FixedTableTag.range_table_tag(range), value) def byte_range_lookup(self, value: FQ): assert isinstance(value, FQ), f"Expect type FQ, but get type {type(value)}" @@ -329,24 +327,26 @@ def fixed_lookup(self, tag: FixedTableTag, inputs: Sequence[FQ]) -> FixedTableRo return self.tables.fixed_lookup(tag, *inputs) def block_context_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ.zero()) -> FQ: - return self.tables.block_lookup([tag, index])[2] + return self.tables.block_lookup(tag, index).value def tx_context_lookup( self, tx_id: FQ, field_tag: TxContextFieldTag, index: FQ = FQ.zero() - ) -> Union[FQ, RLC]: - return self.tables.tx_lookup([tx_id, field_tag, index])[3] + ) -> FQ: + return self.tables.tx_lookup(tx_id, field_tag, index).value def tx_calldata_lookup(self, tx_id: FQ, index: FQ) -> FQ: - return self.tables.tx_lookup([tx_id, TxContextFieldTag.CallData, index])[3] + return self.tables.tx_lookup(tx_id, TxContextFieldTag.CallData, index).value def bytecode_lookup(self, bytecode_hash: RLC, index: FQ, is_code: FQ) -> FQ: - return self.tables.bytecode_lookup([bytecode_hash, index, Tables._, is_code])[2] + return self.tables.bytecode_lookup(bytecode_hash.value, index, is_code).byte def tx_gas_price(self, tx_id: FQ) -> FQ: return self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice) def responsible_opcode_lookup(self, opcode: int): - self.fixed_lookup(FixedTableTag.ResponsibleOpcode, [self.curr.execution_state, opcode]) + self.tables.fixed_lookup( + FixedTableTag.ResponsibleOpcode, FQ(self.curr.execution_state.value), FQ(opcode) + ) def opcode_lookup(self, is_code: bool) -> FQ: index = self.curr.program_counter + self.program_counter_offset @@ -359,26 +359,16 @@ def opcode_lookup_at(self, index: FQ, is_code: bool) -> FQ: "The opcode source when is_root and is_create (root creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.code_source, index, is_code) + return self.bytecode_lookup(self.curr.code_source, index, FQ(is_code)) - def rw_lookup(self, rw: RW, tag: RWTableTag, inputs: Sequence[int], rw_counter: Optional[int] = None) -> RWTableRow: + def rw_lookup( + self, rw: RW, tag: RWTableTag, inputs: Sequence[IntOrFQ], rw_counter: Optional[FQ] = None + ) -> RWTableRow: if rw_counter is None: rw_counter = self.curr.rw_counter + self.rw_counter_offset self.rw_counter_offset += 1 - return self.tables.rw_lookup([rw_counter, rw, tag] + inputs) - def state_write_only_persistent( - self, - tag: RWTableTag, - inputs: Sequence[int], - is_persistent: bool, - ) -> RWTableRow: - assert tag.write_only_persistent() - - if is_persistent: - return self.rw_lookup(RW.Write, tag, inputs) - - return 10 * [None] + return self.tables.rw_lookup(rw_counter, rw, tag, *inputs) def state_write_with_reversion( self, @@ -386,61 +376,73 @@ def state_write_with_reversion( inputs: Sequence[int], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + _state_write_counter: Optional[FQ] = None, ) -> RWTableRow: assert tag.write_with_reversion() row = self.rw_lookup(RW.Write, tag, inputs) - if state_write_counter is None: - state_write_counter = self.curr.state_write_counter + self.state_write_counter_offset + if _state_write_counter is None: self.state_write_counter_offset += 1 + state_write_counter = ( + self.curr.state_write_counter + self.state_write_counter_offset + if _state_write_counter is None + else _state_write_counter + ) + rw_counter = rw_counter_end_of_reversion - state_write_counter if not is_persistent: # Swap value and value_prev - inputs = list(row[3:]) - inputs[-3], inputs[-4] = inputs[-4], inputs[-3] - self.rw_lookup(RW.Write, tag, inputs, rw_counter=rw_counter) + inputs2 = [ + row.key2, + row.key3, + row.key4, + row.value_prev, + row.value, + row.aux1, + row.aux2, + ] + self.rw_lookup(RW.Write, tag, inputs2, rw_counter=rw_counter) return row def call_context_lookup( - self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[int] = None - ) -> Union[FQ, RLC]: - if call_id is None: - call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag])[-4] + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, _call_id: Optional[FQ] = None + ) -> FQ: + call_id = self.curr.call_id if _call_id is None else _call_id + + return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag.value]).value def stack_pop(self) -> Union[FQ, RLC]: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 - return self.stack_lookup(False, stack_pointer_offset) + return self.stack_lookup(RW.Read, stack_pointer_offset) def stack_push(self) -> Union[FQ, RLC]: self.stack_pointer_offset -= 1 - return self.stack_lookup(True, self.stack_pointer_offset) + return self.stack_lookup(RW.Write, self.stack_pointer_offset) def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[FQ, RLC]: stack_pointer = self.curr.stack_pointer + stack_pointer_offset - return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer])[-4] + return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer]).value - def memory_write(self, memory_address: int, call_id: Optional[int] = None) -> FQ: + def memory_write(self, memory_address: int, call_id: Optional[FQ] = None) -> FQ: return self.memory_lookup(RW.Write, memory_address, call_id) - def memory_lookup(self, rw: RW, memory_address: int, call_id: Optional[int] = None) -> FQ: - if call_id is None: - call_id = self.curr.call_id - return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address])[-4] + def memory_lookup(self, rw: RW, memory_address: int, _call_id: Optional[FQ] = None) -> FQ: + call_id = self.curr.call_id if _call_id is None else _call_id + + return self.rw_lookup(rw, RWTableTag.Memory, [call_id, memory_address]).value def tx_refund_read(self, tx_id) -> FQ: row = self.rw_lookup(RW.Read, RWTableTag.TxRefund, [tx_id]) - return row[-4] + return row.value def account_read(self, account_address: int, account_field_tag: AccountFieldTag) -> FQ: row = self.rw_lookup(RW.Read, RWTableTag.Account, [account_address, account_field_tag]) - return row[-4] + return row.value def account_write( self, @@ -452,7 +454,7 @@ def account_write( RWTableTag.Account, [account_address, account_field_tag], ) - return row[-4], row[-3] + return row.value, row.value_prev def account_write_with_reversion( self, @@ -460,7 +462,7 @@ def account_write_with_reversion( account_field_tag: AccountFieldTag, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> Tuple[FQ, FQ]: row = self.state_write_with_reversion( RWTableTag.Account, @@ -469,22 +471,22 @@ def account_write_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4], row[-3] + return row.value, row.value_prev - def add_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: + def add_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[FQ, FQ]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([balance_prev, *values]) - self.constrain_equal(balance, result) + result, carry = self.add_words([RLC(balance_prev, self.randomness), *values]) + self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev def add_balance_with_reversion( self, account_address: int, - values: Sequence[int], + values: Sequence[RLC], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> Tuple[FQ, FQ]: balance, balance_prev = self.account_write_with_reversion( account_address, @@ -493,25 +495,25 @@ def add_balance_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - result, carry = self.add_words([balance_prev, *values]) - self.constrain_equal(balance, result) + result, carry = self.add_words([RLC(balance_prev, self.randomness), *values]) + self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[int]) -> Tuple[FQ, FQ]: + def sub_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[FQ, FQ]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([balance, *values]) - self.constrain_equal(balance_prev, result) + result, carry = self.add_words([RLC(balance, self.randomness), *values]) + self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev def sub_balance_with_reversion( self, account_address: int, - values: Sequence[int], + values: Sequence[RLC], is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> Tuple[FQ, FQ]: balance, balance_prev = self.account_write_with_reversion( account_address, @@ -520,8 +522,8 @@ def sub_balance_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - result, carry = self.add_words([balance, *values]) - self.constrain_equal(balance_prev, result) + result, carry = self.add_words([RLC(balance, self.randomness), *values]) + self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev @@ -535,7 +537,7 @@ def add_account_to_access_list( RWTableTag.TxAccessListAccount, [tx_id, account_address, 0, 1], ) - return row[-4] - row[-3] + return row.value - row.value_prev def add_account_to_access_list_with_reversion( self, @@ -543,7 +545,7 @@ def add_account_to_access_list_with_reversion( account_address: int, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> FQ: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccount, @@ -552,7 +554,7 @@ def add_account_to_access_list_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4] - row[-3] + return row.value - row.value_prev def add_account_storage_to_access_list( self, @@ -565,7 +567,7 @@ def add_account_storage_to_access_list( RWTableTag.TxAccessListAccountStorage, [tx_id, account_address, storage_key, 1], ) - return row[-4] - row[-3] + return (row.value - row.value_prev) == 1 def add_account_storage_to_access_list_with_reversion( self, @@ -574,7 +576,7 @@ def add_account_storage_to_access_list_with_reversion( storage_key: int, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, + state_write_counter: Optional[FQ] = None, ) -> bool: row = self.state_write_with_reversion( RWTableTag.TxAccessListAccountStorage, @@ -583,14 +585,14 @@ def add_account_storage_to_access_list_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - return row[-4] - row[-3] + return (row.value - row.value_prev) == 1 def transfer_with_gas_fee( self, sender_address: int, receiver_address: int, - value: int, - gas_fee: int, + value: RLC, + gas_fee: RLC, is_persistent: bool, rw_counter_end_of_reversion: int, ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: @@ -612,11 +614,11 @@ def transfer( self, sender_address: int, receiver_address: int, - value: int, + value: RLC, is_persistent: bool, rw_counter_end_of_reversion: int, - state_write_counter: Optional[int] = None, - ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + state_write_counter: Optional[FQ] = None, + ) -> Tuple[Tuple[FQ, FQ], Tuple[FQ, FQ]]: sender_balance_pair = self.sub_balance_with_reversion( sender_address, [value], @@ -634,11 +636,11 @@ def transfer( return sender_balance_pair, receiver_balance_pair def memory_offset_and_length(self, offset: RLC, length: RLC) -> Tuple[FQ, FQ]: - length = self.rlc_to_fq_exact(length, N_BYTES_MEMORY_SIZE) - if self.is_zero(length): + _length = self.rlc_to_fq_exact(length, N_BYTES_MEMORY_SIZE) + if self.is_zero(_length): return FQ.zero(), FQ.zero() - offset = self.rlc_to_fq_exact(offset, N_BYTES_MEMORY_ADDRESS) - return offset, length + _offset = self.rlc_to_fq_exact(offset, N_BYTES_MEMORY_ADDRESS) + return _offset, _length def memory_gas_cost(self, memory_size: FQ) -> FQ: quadratic_cost, _ = self.constant_divmod( @@ -652,7 +654,7 @@ def memory_expansion_constant_length(self, offset: FQ, length: FQ) -> Tuple[FQ, next_memory_size = self.max(self.curr.memory_size, memory_size, N_BYTES_MEMORY_SIZE) - memory_gas_cost = self.memory_expansion_gas_cost(self.curr.memory_size) + memory_gas_cost = self.memory_gas_cost(self.curr.memory_size) memory_gas_cost_next = self.memory_gas_cost(next_memory_size) memory_expansion_gas_cost = memory_gas_cost_next - memory_gas_cost @@ -671,6 +673,7 @@ def memory_expansion_dynamic_length( next_memory_size = self.max(self.curr.memory_size, cd_memory_size, N_BYTES_MEMORY_SIZE) if rd_offset is not None: + rd_length = 0 if _rd_length is None else _rd_length rd_memory_size, _ = self.constant_divmod( rd_offset + rd_length + 31, 32, N_BYTES_MEMORY_SIZE ) diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 124671a6a..75feb83aa 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -250,13 +250,13 @@ def __init__(self, table_name: str, diff: Set[str]) -> None: class LookupUnsatFailure(Exception): - def __init__(self, table_name: str, inputs: Tuple[int, ...]) -> None: + def __init__(self, table_name: str, inputs: Any) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is unsatisfied on inputs {inputs}" class LookupAmbiguousFailure(Exception): - def __init__(self, table_name: str, inputs: Tuple[int, ...], matched_rows: Sequence[Any]) -> None: + def __init__(self, table_name: str, inputs: Any, matched_rows: Sequence[Any]) -> None: self.inputs = inputs self.message = f"Lookup {table_name} is ambiguous on inputs {inputs}, ${len(matched_rows)} matched rows found: {matched_rows}" @@ -364,8 +364,12 @@ def block_lookup(self, tag: BlockContextFieldTag, index: FQ = FQ(0)) -> BlockTab query: Dict[str, Optional[IntOrFQ]] = {"tag": tag, "block_number_or_zero": index} return _lookup(BlockTableRow, self.block_table, query) - def tx_lookup(self, tx_id: int, field_tag: TxContextFieldTag, index: FQ) -> TxTableRow: - query: Dict[str, Optional[IntOrFQ]] = {"tx_id": tx_id, "tag": field_tag, "call_data_index_or_zero": index} + def tx_lookup(self, tx_id: FQ, field_tag: TxContextFieldTag, index: FQ) -> TxTableRow: + query: Dict[str, Optional[IntOrFQ]] = { + "tx_id": tx_id, + "tag": field_tag, + "call_data_index_or_zero": index, + } return _lookup(TxTableRow, self.tx_table, query) def bytecode_lookup(self, bytecode_hash: FQ, index: FQ, is_code: FQ) -> BytecodeTableRow: @@ -376,8 +380,24 @@ def bytecode_lookup(self, bytecode_hash: FQ, index: FQ, is_code: FQ) -> Bytecode } return _lookup(BytecodeTableRow, self.bytecode_table, query) - def rw_lookup(self, rw_counter: FQ, rw: RW, tag: RWTableTag, **other_queries: IntOrFQ) -> RWTableRow: - query: Dict[str, Optional[IntOrFQ]] = {"rw_counter": rw_counter, "rw": int(rw.value), "tag": tag, **other_queries} + def rw_lookup( + self, rw_counter: FQ, rw: RW, tag: RWTableTag, *other_queries: IntOrFQ + ) -> RWTableRow: + rest_keys = [ + "key2", + "key3", + "key4", + "value", + "value_prev", + "aux1", + "aux2", + ] + query: Dict[str, Optional[IntOrFQ]] = { + "rw_counter": rw_counter, + "rw": int(rw.value), + "tag": tag, + **zip(rest_keys, other_queries), + } return _lookup(RWTableRow, self.rw_table, query) diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index f9be2314a..acfb90f96 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -25,7 +25,7 @@ class RLC: value: FQ def __init__( - self, int_or_bytes: Union[IntOrFQ, bytes], randomness: int, n_bytes: int = 32 + self, int_or_bytes: Union[IntOrFQ, bytes], randomness: IntOrFQ, n_bytes: int = 32 ) -> None: if isinstance(int_or_bytes, int): assert ( From d56888ca749b62f03348dea79061aee7f428625b Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Thu, 17 Feb 2022 17:21:21 +0800 Subject: [PATCH 12/15] slt sgt --- src/zkevm_specs/evm/execution/slt_sgt.py | 4 ++-- src/zkevm_specs/evm/instruction.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/zkevm_specs/evm/execution/slt_sgt.py b/src/zkevm_specs/evm/execution/slt_sgt.py index 752ef1bc1..0cff3a022 100644 --- a/src/zkevm_specs/evm/execution/slt_sgt.py +++ b/src/zkevm_specs/evm/execution/slt_sgt.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple +from zkevm_specs.util import FQ from ..instruction import Instruction, Transition from ..opcode import Opcode @@ -7,7 +7,7 @@ def scmp(instruction: Instruction): opcode = instruction.opcode_lookup(True) - is_sgt, _ = instruction.pair_select(opcode, Opcode.SGT, Opcode.SLT) + is_sgt, _ = instruction.pair_select(opcode, FQ(Opcode.SGT.value), FQ(Opcode.SLT.value)) a = instruction.stack_pop() b = instruction.stack_pop() diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index 2185c18bc..c15aa5ee9 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -415,16 +415,16 @@ def call_context_lookup( return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag.value]).value - def stack_pop(self) -> Union[FQ, RLC]: + def stack_pop(self) -> FQ: stack_pointer_offset = self.stack_pointer_offset self.stack_pointer_offset += 1 return self.stack_lookup(RW.Read, stack_pointer_offset) - def stack_push(self) -> Union[FQ, RLC]: + def stack_push(self) -> FQ: self.stack_pointer_offset -= 1 return self.stack_lookup(RW.Write, self.stack_pointer_offset) - def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> Union[FQ, RLC]: + def stack_lookup(self, rw: RW, stack_pointer_offset: int) -> FQ: stack_pointer = self.curr.stack_pointer + stack_pointer_offset return self.rw_lookup(rw, RWTableTag.Stack, [self.curr.call_id, stack_pointer]).value @@ -665,7 +665,7 @@ def memory_expansion_dynamic_length( cd_offset: FQ, cd_length: FQ, rd_offset: Optional[FQ] = None, - rd_length: Optional[FQ] = None, + _rd_length: Optional[FQ] = None, ) -> Tuple[FQ, FQ]: cd_memory_size, _ = self.constant_divmod( cd_offset + cd_length + 31, 32, N_BYTES_MEMORY_SIZE From 46b2459d51208ff4a3ad4979acabb0fdcb0a493c Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Thu, 17 Feb 2022 20:49:41 +0800 Subject: [PATCH 13/15] make rlc FQ --- src/zkevm_specs/bytecode.py | 2 +- src/zkevm_specs/evm/step.py | 2 +- src/zkevm_specs/evm/typing.py | 6 +++--- src/zkevm_specs/util/arithmetic.py | 5 ++--- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index d198dab3b..1beca718a 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -178,7 +178,7 @@ def assign_push_table(): # Generate keccak table -def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int): +def assign_keccak_table(bytecodes: Sequence[bytes], randomness: FQ): keccak_table = [] for bytecode in bytecodes: hash = RLC(bytes(reversed(keccak256(bytecode))), randomness) diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index 6914eb000..862189c8a 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -50,7 +50,7 @@ def __init__( call_id: int = 0, is_root: bool = False, is_create: bool = False, - code_source: RLC = RLC(0, 0), + code_source: RLC = RLC(0, FQ(0)), program_counter: int = 0, stack_pointer: int = 1024, gas_left: int = 0, diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 528497281..5b29335e8 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -61,7 +61,7 @@ def __init__( self.base_fee = base_fee self.history_hashes = history_hashes - def table_assignments(self, randomness: int) -> List[BlockTableRow]: + def table_assignments(self, randomness: FQ) -> List[BlockTableRow]: return [ BlockTableRow(BlockContextFieldTag.Coinbase, FQ(0), FQ(self.coinbase)), BlockTableRow(BlockContextFieldTag.GasLimit, FQ(0), FQ(self.gas_limit)), @@ -127,7 +127,7 @@ def call_data_gas_cost(self) -> int: 0, ) - def table_assignments(self, randomness: int) -> Iterator[TxTableRow]: + def table_assignments(self, randomness: FQ) -> Iterator[TxTableRow]: return chain( [ TxTableRow(FQ(self.id), TxContextFieldTag.Nonce, FQ(0), FQ(self.nonce)), @@ -221,7 +221,7 @@ def push(self, value: Any, n_bytes: int = 32) -> Bytecode: def hash(self) -> U256: return U256(int.from_bytes(keccak256(self.code), "big")) - def table_assignments(self, randomness: int) -> Iterator[BytecodeTableRow]: + def table_assignments(self, randomness: FQ) -> Iterator[BytecodeTableRow]: class BytecodeIterator: idx: int push_data_left: int diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index acfb90f96..b199f71c1 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -11,9 +11,8 @@ def _hash_fq(v: FQ) -> int: IntOrFQ = Union[int, FQ] -def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], _factor: int) -> FQ: +def fp_linear_combine(le_bytes: Union[bytes, Sequence[int]], factor: FQ) -> FQ: ret = FQ.zero() - factor = FQ(_factor) for byte in reversed(le_bytes): assert 0 <= byte < 256, "Each byte in le_bytes for linear combination should fit in 8-bit" ret = ret * factor + byte @@ -25,7 +24,7 @@ class RLC: value: FQ def __init__( - self, int_or_bytes: Union[IntOrFQ, bytes], randomness: IntOrFQ, n_bytes: int = 32 + self, int_or_bytes: Union[IntOrFQ, bytes], randomness: FQ, n_bytes: int = 32 ) -> None: if isinstance(int_or_bytes, int): assert ( From 6e8e1b3b6d587288e0e1cf8925a0f954740aa33f Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Thu, 17 Feb 2022 20:53:28 +0800 Subject: [PATCH 14/15] ignore rest of the typing --- src/zkevm_specs/evm/execution/__init__.py | 1 + src/zkevm_specs/evm/execution/add.py | 1 + src/zkevm_specs/evm/execution/begin_tx.py | 1 + src/zkevm_specs/evm/execution/block_coinbase.py | 1 + src/zkevm_specs/evm/execution/block_timestamp.py | 1 + src/zkevm_specs/evm/execution/calldatacopy.py | 1 + src/zkevm_specs/evm/execution/calldatasize.py | 1 + src/zkevm_specs/evm/execution/caller.py | 1 + src/zkevm_specs/evm/execution/callvalue.py | 1 + src/zkevm_specs/evm/execution/end_block.py | 1 + src/zkevm_specs/evm/execution/end_tx.py | 1 + src/zkevm_specs/evm/execution/gas.py | 1 + src/zkevm_specs/evm/execution/gasprice.py | 1 + src/zkevm_specs/evm/execution/jump.py | 1 + src/zkevm_specs/evm/execution/jumpi.py | 1 + src/zkevm_specs/evm/execution/memory_copy.py | 3 +-- src/zkevm_specs/evm/execution/push.py | 1 + src/zkevm_specs/evm/execution/selfbalance.py | 1 + src/zkevm_specs/evm/execution/slt_sgt.py | 2 ++ src/zkevm_specs/evm/main.py | 1 + 20 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index f282cd9cf..68cac17b3 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -1,3 +1,4 @@ +# type: ignore from typing import Callable, Dict from ..execution_state import ExecutionState diff --git a/src/zkevm_specs/evm/execution/add.py b/src/zkevm_specs/evm/execution/add.py index 8bd6f2e54..a0e09b61a 100644 --- a/src/zkevm_specs/evm/execution/add.py +++ b/src/zkevm_specs/evm/execution/add.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index 58de3305c..edaa819fc 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH from ..execution_state import ExecutionState from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/execution/block_coinbase.py b/src/zkevm_specs/evm/execution/block_coinbase.py index 55d0b3b20..f2ea8f845 100644 --- a/src/zkevm_specs/evm/execution/block_coinbase.py +++ b/src/zkevm_specs/evm/execution/block_coinbase.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/block_timestamp.py b/src/zkevm_specs/evm/execution/block_timestamp.py index c503e2750..6ab10c9dc 100644 --- a/src/zkevm_specs/evm/execution/block_timestamp.py +++ b/src/zkevm_specs/evm/execution/block_timestamp.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import BlockContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/calldatacopy.py b/src/zkevm_specs/evm/execution/calldatacopy.py index 599b3dc58..251612d0e 100644 --- a/src/zkevm_specs/evm/execution/calldatacopy.py +++ b/src/zkevm_specs/evm/execution/calldatacopy.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import N_BYTES_MEMORY_ADDRESS, FQ from ..execution_state import ExecutionState from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/execution/calldatasize.py b/src/zkevm_specs/evm/execution/calldatasize.py index 2495577a0..1797b62cd 100644 --- a/src/zkevm_specs/evm/execution/calldatasize.py +++ b/src/zkevm_specs/evm/execution/calldatasize.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/caller.py b/src/zkevm_specs/evm/execution/caller.py index 58bf046ea..e633d6946 100644 --- a/src/zkevm_specs/evm/execution/caller.py +++ b/src/zkevm_specs/evm/execution/caller.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/callvalue.py b/src/zkevm_specs/evm/execution/callvalue.py index 4e37a5b14..16ecb5907 100644 --- a/src/zkevm_specs/evm/execution/callvalue.py +++ b/src/zkevm_specs/evm/execution/callvalue.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/end_block.py b/src/zkevm_specs/evm/execution/end_block.py index 4ba98a1e9..6d3e81a7a 100644 --- a/src/zkevm_specs/evm/execution/end_block.py +++ b/src/zkevm_specs/evm/execution/end_block.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import CallContextFieldTag diff --git a/src/zkevm_specs/evm/execution/end_tx.py b/src/zkevm_specs/evm/execution/end_tx.py index e9d295b32..76f5cb89b 100644 --- a/src/zkevm_specs/evm/execution/end_tx.py +++ b/src/zkevm_specs/evm/execution/end_tx.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED from ..execution_state import ExecutionState from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/execution/gas.py b/src/zkevm_specs/evm/execution/gas.py index d0a0faea5..a56e3d237 100644 --- a/src/zkevm_specs/evm/execution/gas.py +++ b/src/zkevm_specs/evm/execution/gas.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import CallContextFieldTag, TxContextFieldTag diff --git a/src/zkevm_specs/evm/execution/gasprice.py b/src/zkevm_specs/evm/execution/gasprice.py index 3323692f1..ee2751f99 100644 --- a/src/zkevm_specs/evm/execution/gasprice.py +++ b/src/zkevm_specs/evm/execution/gasprice.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode from ..table import CallContextFieldTag, TxContextFieldTag diff --git a/src/zkevm_specs/evm/execution/jump.py b/src/zkevm_specs/evm/execution/jump.py index 8e46d6d1f..b33a8cbc9 100644 --- a/src/zkevm_specs/evm/execution/jump.py +++ b/src/zkevm_specs/evm/execution/jump.py @@ -1,3 +1,4 @@ +# type: ignore from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/jumpi.py b/src/zkevm_specs/evm/execution/jumpi.py index c485b0b74..d8a2d66f0 100644 --- a/src/zkevm_specs/evm/execution/jumpi.py +++ b/src/zkevm_specs/evm/execution/jumpi.py @@ -1,3 +1,4 @@ +# type: ignore from ...util.param import N_BYTES_PROGRAM_COUNTER from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/memory_copy.py b/src/zkevm_specs/evm/execution/memory_copy.py index cabb93b71..43ac76008 100644 --- a/src/zkevm_specs/evm/execution/memory_copy.py +++ b/src/zkevm_specs/evm/execution/memory_copy.py @@ -1,3 +1,4 @@ +# type: ignore from ...util import FQ, N_BYTES_MEMORY_SIZE from ..execution_state import ExecutionState from ..instruction import Instruction, Transition @@ -17,8 +18,6 @@ def copy_to_memory(instruction: Instruction): instruction, MAX_COPY_BYTES, aux.src_addr, aux.src_addr_end, aux.bytes_left ) - data = [] - rw_counter_delta = 0 for i in range(MAX_COPY_BYTES): if not buffer_reader.read_flag(i): byte = FQ.zero() diff --git a/src/zkevm_specs/evm/execution/push.py b/src/zkevm_specs/evm/execution/push.py index cb5602684..a2927c9e9 100644 --- a/src/zkevm_specs/evm/execution/push.py +++ b/src/zkevm_specs/evm/execution/push.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/selfbalance.py b/src/zkevm_specs/evm/execution/selfbalance.py index b69448757..b746573e4 100644 --- a/src/zkevm_specs/evm/execution/selfbalance.py +++ b/src/zkevm_specs/evm/execution/selfbalance.py @@ -1,3 +1,4 @@ +# type: ignore from ..instruction import Instruction, Transition from ..table import AccountFieldTag, CallContextFieldTag from ..opcode import Opcode diff --git a/src/zkevm_specs/evm/execution/slt_sgt.py b/src/zkevm_specs/evm/execution/slt_sgt.py index 0cff3a022..1d06775be 100644 --- a/src/zkevm_specs/evm/execution/slt_sgt.py +++ b/src/zkevm_specs/evm/execution/slt_sgt.py @@ -1,3 +1,5 @@ +# type: ignore + from zkevm_specs.util import FQ from ..instruction import Instruction, Transition diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index f990f9100..f412c3289 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -1,3 +1,4 @@ +# type: ignore from typing import Sequence from ..util import FQ From 1ff5aa7e8680b2857c240cc47abe408282b9907d Mon Sep 17 00:00:00 2001 From: ChihChengLiang Date: Thu, 17 Feb 2022 23:38:48 +0800 Subject: [PATCH 15/15] attempt to fix tests --- src/zkevm_specs/evm/execution/begin_tx.py | 2 +- src/zkevm_specs/evm/instruction.py | 24 ++++++------ src/zkevm_specs/evm/table.py | 46 +++++++++++++++-------- src/zkevm_specs/evm/typing.py | 23 ++++++------ src/zkevm_specs/util/arithmetic.py | 3 ++ tests/evm/test_gasprice.py | 13 ++++++- tests/test_bytecode_circuit.py | 8 ++-- 7 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/zkevm_specs/evm/execution/begin_tx.py b/src/zkevm_specs/evm/execution/begin_tx.py index edaa819fc..1fd2c9814 100644 --- a/src/zkevm_specs/evm/execution/begin_tx.py +++ b/src/zkevm_specs/evm/execution/begin_tx.py @@ -1,5 +1,5 @@ # type: ignore -from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH +from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH, RLC from ..execution_state import ExecutionState from ..instruction import Instruction, Transition from ..precompiled import PrecompiledAddress diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index c15aa5ee9..d37b3d5ea 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -409,11 +409,11 @@ def state_write_with_reversion( return row def call_context_lookup( - self, field_tag: CallContextFieldTag, rw: RW = RW.Read, _call_id: Optional[FQ] = None + self, field_tag: CallContextFieldTag, rw: RW = RW.Read, call_id: Optional[FQ] = None ) -> FQ: - call_id = self.curr.call_id if _call_id is None else _call_id + _call_id = self.curr.call_id if call_id is None else call_id - return self.rw_lookup(rw, RWTableTag.CallContext, [call_id, field_tag.value]).value + return self.rw_lookup(rw, RWTableTag.CallContext, [_call_id, field_tag.value]).value def stack_pop(self) -> FQ: stack_pointer_offset = self.stack_pointer_offset @@ -463,7 +463,7 @@ def account_write_with_reversion( is_persistent: bool, rw_counter_end_of_reversion: int, state_write_counter: Optional[FQ] = None, - ) -> Tuple[FQ, FQ]: + ) -> Tuple[RLC, RLC]: row = self.state_write_with_reversion( RWTableTag.Account, [account_address, account_field_tag], @@ -473,9 +473,9 @@ def account_write_with_reversion( ) return row.value, row.value_prev - def add_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[FQ, FQ]: + def add_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([RLC(balance_prev, self.randomness), *values]) + result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev @@ -487,7 +487,7 @@ def add_balance_with_reversion( is_persistent: bool, rw_counter_end_of_reversion: int, state_write_counter: Optional[FQ] = None, - ) -> Tuple[FQ, FQ]: + ) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write_with_reversion( account_address, AccountFieldTag.Balance, @@ -495,14 +495,14 @@ def add_balance_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - result, carry = self.add_words([RLC(balance_prev, self.randomness), *values]) + result, carry = self.add_words([balance_prev, *values]) self.constrain_equal(balance, result.value) self.constrain_zero(carry) return balance, balance_prev - def sub_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[FQ, FQ]: + def sub_balance(self, account_address: int, values: Sequence[RLC]) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write(account_address, AccountFieldTag.Balance) - result, carry = self.add_words([RLC(balance, self.randomness), *values]) + result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev @@ -514,7 +514,7 @@ def sub_balance_with_reversion( is_persistent: bool, rw_counter_end_of_reversion: int, state_write_counter: Optional[FQ] = None, - ) -> Tuple[FQ, FQ]: + ) -> Tuple[RLC, RLC]: balance, balance_prev = self.account_write_with_reversion( account_address, AccountFieldTag.Balance, @@ -522,7 +522,7 @@ def sub_balance_with_reversion( rw_counter_end_of_reversion, state_write_counter, ) - result, carry = self.add_words([RLC(balance, self.randomness), *values]) + result, carry = self.add_words([balance, *values]) self.constrain_equal(balance_prev, result.value) self.constrain_zero(carry) return balance, balance_prev diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 75feb83aa..d8379d148 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import Mapping, Sequence, Set, List, TypeVar, Any, Type, Optional, Dict +from typing import Mapping, Sequence, Set, List, TypeVar, Any, Type, Optional, Dict, Union from enum import IntEnum, auto, Enum from itertools import chain, product from dataclasses import dataclass, field, asdict, fields - -from ..util import FQ, IntOrFQ +from ..util import FQ, IntOrFQ, RLC from .execution_state import ExecutionState from .opcode import ( invalid_opcodes, @@ -155,6 +154,9 @@ class RW(Enum): Read = False Write = True + def __int__(self): + return self.value + class RWTableTag(IntEnum): """ @@ -287,7 +289,7 @@ class BlockTableRow(TableRow): tag: BlockContextFieldTag # meaningful only for HistoryHash, will be zero for other tags block_number_or_zero: FQ - value: FQ + value: Union[FQ, RLC] @dataclass(frozen=True) @@ -296,7 +298,7 @@ class TxTableRow(TableRow): tag: TxContextFieldTag # meaningful only for CallData, will be zero for other tags call_data_index_or_zero: FQ - value: FQ + value: Union[FQ, RLC] @dataclass(frozen=True) @@ -339,15 +341,27 @@ class Tables: def __init__( self, - block_table: Set[BlockTableRow], - tx_table: Set[TxTableRow], - bytecode_table: Set[BytecodeTableRow], - rw_table: Set[RWTableRow], + block_table: Union[Set[Sequence[IntOrFQ]], Set[BlockTableRow]], + tx_table: Union[Set[Sequence[IntOrFQ]], Set[TxTableRow]], + bytecode_table: Union[Set[Sequence[IntOrFQ]], Set[BytecodeTableRow]], + rw_table: Union[Set[Sequence[IntOrFQ]], Set[RWTableRow]], ) -> None: - self.block_table = block_table - self.tx_table = tx_table - self.bytecode_table = bytecode_table - self.rw_table = rw_table + self.block_table = set( + row if isinstance(row, BlockTableRow) else BlockTableRow(*row) # type: ignore # (BlockTableRow input args) + for row in block_table + ) + self.tx_table = set( + row if isinstance(row, TxTableRow) else TxTableRow(*row) # type: ignore # (TxTableRow input args) + for row in tx_table + ) + self.bytecode_table = set( + row if isinstance(row, BytecodeTableRow) else BytecodeTableRow(*row) # type: ignore # (BytecodeTableRow input args) + for row in bytecode_table + ) + self.rw_table = set( + row if isinstance(row, RWTableRow) else RWTableRow(*row) # type: ignore # (RWTableRow input args) + for row in rw_table + ) def fixed_lookup( self, tag: FixedTableTag, value1: FQ, value2: FQ = None, value3: FQ = None @@ -394,9 +408,9 @@ def rw_lookup( ] query: Dict[str, Optional[IntOrFQ]] = { "rw_counter": rw_counter, - "rw": int(rw.value), - "tag": tag, - **zip(rest_keys, other_queries), + "is_write": int(rw.value), + "key1": tag, + **dict(zip(rest_keys, other_queries)), } return _lookup(RWTableRow, self.rw_table, query) diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index 5b29335e8..05bd6847f 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -65,19 +65,15 @@ def table_assignments(self, randomness: FQ) -> List[BlockTableRow]: return [ BlockTableRow(BlockContextFieldTag.Coinbase, FQ(0), FQ(self.coinbase)), BlockTableRow(BlockContextFieldTag.GasLimit, FQ(0), FQ(self.gas_limit)), - BlockTableRow(BlockContextFieldTag.Number, FQ(0), RLC(self.number, randomness).value), + BlockTableRow(BlockContextFieldTag.Number, FQ(0), RLC(self.number, randomness)), BlockTableRow(BlockContextFieldTag.Timestamp, FQ(0), FQ(self.timestamp)), - BlockTableRow( - BlockContextFieldTag.Difficulty, FQ(0), RLC(self.difficulty, randomness).value - ), - BlockTableRow( - BlockContextFieldTag.BaseFee, FQ(0), RLC(self.base_fee, randomness).value - ), + BlockTableRow(BlockContextFieldTag.Difficulty, FQ(0), RLC(self.difficulty, randomness)), + BlockTableRow(BlockContextFieldTag.BaseFee, FQ(0), RLC(self.base_fee, randomness)), ] + [ BlockTableRow( BlockContextFieldTag.HistoryHash, FQ(self.number - idx - 1), - RLC(history_hash, randomness).value, + RLC(history_hash, randomness), ) for idx, history_hash in enumerate(reversed(self.history_hashes)) ] @@ -136,19 +132,22 @@ def table_assignments(self, randomness: FQ) -> Iterator[TxTableRow]: FQ(self.id), TxContextFieldTag.GasPrice, FQ(0), - RLC(self.gas_price, randomness).value, + RLC(self.gas_price, randomness), ), TxTableRow( FQ(self.id), TxContextFieldTag.CallerAddress, FQ(0), FQ(self.caller_address) ), TxTableRow( - FQ(self.id), TxContextFieldTag.CalleeAddress, FQ(0), FQ(self.callee_address) + FQ(self.id), + TxContextFieldTag.CalleeAddress, + FQ(0), + FQ(self.callee_address or 0), ), TxTableRow( FQ(self.id), TxContextFieldTag.IsCreate, FQ(0), FQ(self.callee_address is None) ), TxTableRow( - FQ(self.id), TxContextFieldTag.Value, FQ(0), RLC(self.value, randomness).value + FQ(self.id), TxContextFieldTag.Value, FQ(0), RLC(self.value, randomness) ), TxTableRow( FQ(self.id), TxContextFieldTag.CallDataLength, FQ(0), FQ(len(self.call_data)) @@ -249,7 +248,7 @@ def __next__(self): self.idx += 1 - return (self.hash, idx, byte, is_code) + return (self.hash.value, idx, byte, is_code) return BytecodeIterator(RLC(self.hash(), randomness), self.code) diff --git a/src/zkevm_specs/util/arithmetic.py b/src/zkevm_specs/util/arithmetic.py index b199f71c1..9a7d1fe8e 100644 --- a/src/zkevm_specs/util/arithmetic.py +++ b/src/zkevm_specs/util/arithmetic.py @@ -62,5 +62,8 @@ def __hash__(self) -> int: def __repr__(self) -> str: return "RLC(%s)" % int.from_bytes(self.le_bytes, "little") + def __int__(self) -> int: + return int.from_bytes(self.le_bytes, "little") + def be_bytes(self) -> bytes: return bytes(reversed(self.le_bytes)) diff --git a/tests/evm/test_gasprice.py b/tests/evm/test_gasprice.py index 8b98d0fd9..c2af3cfe5 100644 --- a/tests/evm/test_gasprice.py +++ b/tests/evm/test_gasprice.py @@ -51,7 +51,18 @@ def test_gasprice(gasprice: U256): 0, 0, ), - (10, RW.Write, RWTableTag.Stack, 1, 1023, 0, RLC(gasprice, randomness), 0, 0, 0), + ( + 10, + RW.Write, + RWTableTag.Stack, + 1, + 1023, + 0, + RLC(gasprice, randomness).value, + 0, + 0, + 0, + ), ] ), ) diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index 835e1564c..c5f0148f7 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -92,13 +92,13 @@ def test_bytecode_invalid_hash_data(): # Change the hash on the first position invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Change the hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[0] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash @@ -115,13 +115,13 @@ def test_bytecode_invalid_index(): # Start the index at 1 invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = (row[0].value + 1, row[1], row[2], row[3]) + invalid.rows[idx] = (row[0] + 1, row[1], row[2], row[3]) verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) invalid_cell = invalid.rows[-1][0] - invalid_cell.value -= 1 + invalid_cell -= 1 invalid.rows[-1] = (invalid_cell, row[1], row[2], row[3]) verify(k, [invalid], randomness, False)