diff --git a/specs/tables.md b/specs/tables.md index 15a388221..3af16a10e 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -149,8 +149,4 @@ Proved by the block circuit. | BitwiseAnd | lhs=0..256 | rhs=0..256 | $lhs AND $rhs | | BitwiseOr | lhs=0..256 | rhs=0..256 | $lhs OR $rhs | | BitwiseXor | lhs=0..256 | rhs=0..256 | $lhs XOR $rhs | -| ResponsibleOpcode | $execution_state | execution_state.responsible_opcode() | 0 | -| InvalidOpcode | invalid_opcodes() | 0 | 0 | -| StateWriteOpcode | state_write_opcodes() | 0 | 0 | -| StackOverflow | $overflow_opcode | stack_overflow_pairs\[overflow_opcode\] | 0 | -| StackUnderflow | $underflow_opcode | stack_underflow_pairs\[underflow_opcode\] | 0 | +| ResponsibleOpcode | $execution_state | $responsible_opcode | $auxiliary | diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index cbdc92397..f57b80769 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -1,7 +1,13 @@ from enum import IntEnum, auto -from typing import Sequence +from typing import Sequence, Tuple, Union -from .opcode import Opcode +from .opcode import ( + Opcode, + invalid_opcodes, + state_write_opcodes, + stack_underflow_pairs, + stack_overflow_pairs, +) class ExecutionState(IntEnum): @@ -74,7 +80,7 @@ class ExecutionState(IntEnum): PUSH = auto() # PUSH1, PUSH2, ..., PUSH32 DUP = auto() # DUP1, DUP2, ..., DUP16 SWAP = auto() # SWAP1, SWAP2, ..., SWAP16 - LOG = auto() # LOG1, LOG2, ..., LOG5 + LOG = auto() # LOG0, LOG1, LOG2, LOG3, LOG4 CREATE = auto() CALL = auto() CALLCODE = auto() @@ -87,10 +93,9 @@ class ExecutionState(IntEnum): # Error cases ErrorInvalidOpcode = auto() - # For opcodes which push more than pop - ErrorStackOverflow = auto() - # For opcodes which pop and DUP, SWAP which peek deeper element directly - ErrorStackUnderflow = auto() + # For opcodes which triggers stackoverflow by doing push more than pop, + # or stackunderflow by doing pop and DUP, SWAP which peek deeper element directly + ErrorStack = auto() # For SSTORE, LOG0, LOG1, LOG2, LOG3, LOG4, CREATE, CALL, CREATE2, SELFDESTRUCT ErrorWriteProtection = auto() # For CALL, CALLCODE, DELEGATECALL, STATICCALL @@ -102,34 +107,40 @@ class ExecutionState(IntEnum): ErrorInvalidCreationCode = auto() # For opcode RETURN which needs to store code when it's is creation ErrorMaxCodeSizeExceeded = auto() - # For REVERT - ErrorReverted = auto() # For JUMP, JUMPI ErrorInvalidJump = auto() # For RETURNDATACOPY ErrorReturnDataOutOfBound = auto() # For opcodes which have non-zero constant gas cost ErrorOutOfGasConstant = auto() - # For opcodes MLOAD, MSTORE, MSTORE8, CREATE, RETURN, REVERT, which have pure memory expansion gas cost - ErrorOutOfGasPureMemory = auto() + # For opcodes MLOAD, MSTORE, MSTORE8, which have static size memory expansion gas cost + ErrorOutOfGasStaticMemoryExpansion = auto() + # For opcodes CREATE, RETURN, REVERT, which have dynamic size memory expansion gas cost + ErrorOutOfGasDynamicMemoryExpansion = auto() + # For opcode CALLDATACOPY, CODECOPY, RETURNDATACOPY, which copies a specified chunk of memory + ErrorOutOfGasMemoryCopy = auto() + # For opcodes BALANCE, EXTCODESIZE, EXTCODEHASH, which possibly touches an extra account + ErrorOutOfGasAccountAccess = auto() # For opcode RETURN which has code storing gas cost when it's is creation ErrorOutOfGasCodeStore = auto() - # For opcodes which have dynamic gas usage rather than pure memory expansion + # For opcodes LOG0, LOG1, LOG2, LOG3, LOG4 + ErrorOutOfGasLOG = auto() + # For opcodes which have their own gas calculation + ErrorOutOfGasEXP = auto() ErrorOutOfGasSHA3 = auto() - ErrorOutOfGasCALLDATACOPY = auto() - ErrorOutOfGasCODECOPY = auto() ErrorOutOfGasEXTCODECOPY = auto() - ErrorOutOfGasRETURNDATACOPY = auto() - ErrorOutOfGasLOG = auto() + ErrorOutOfGasSLOAD = auto() + ErrorOutOfGasSSTORE = auto() ErrorOutOfGasCALL = auto() ErrorOutOfGasCALLCODE = auto() ErrorOutOfGasDELEGATECALL = auto() ErrorOutOfGasCREATE2 = auto() ErrorOutOfGasSTATICCALL = auto() + ErrorOutOfGasSELFDESTRUCT = auto() # TODO: Precompile success and error cases - def responsible_opcode(self) -> Sequence[Opcode]: + def responsible_opcode(self) -> Union[Sequence[int], Sequence[Tuple[int, int]]]: if self == ExecutionState.STOP: return [Opcode.STOP] elif self == ExecutionState.ADD: @@ -359,4 +370,52 @@ def responsible_opcode(self) -> Sequence[Opcode]: return [Opcode.REVERT] elif self == ExecutionState.SELFDESTRUCT: return [Opcode.SELFDESTRUCT] + elif self == ExecutionState.ErrorInvalidOpcode: + return invalid_opcodes() + elif self == ExecutionState.ErrorStack: + return stack_overflow_pairs() + stack_underflow_pairs() + elif self == ExecutionState.ErrorWriteProtection: + return state_write_opcodes() return [] + + def halts(self): + return self.halts_in_success() or self.halts_in_exception() or self == ExecutionState.REVERT + + def halts_in_success(self): + return self in [ + ExecutionState.STOP, + ExecutionState.RETURN, + ExecutionState.SELFDESTRUCT, + ] + + def halts_in_exception(self): + return self in [ + ExecutionState.ErrorInvalidOpcode, + ExecutionState.ErrorStack, + ExecutionState.ErrorWriteProtection, + ExecutionState.ErrorDepth, + ExecutionState.ErrorInsufficientBalance, + ExecutionState.ErrorContractAddressCollision, + ExecutionState.ErrorInvalidCreationCode, + ExecutionState.ErrorMaxCodeSizeExceeded, + ExecutionState.ErrorInvalidJump, + ExecutionState.ErrorReturnDataOutOfBound, + ExecutionState.ErrorOutOfGasConstant, + ExecutionState.ErrorOutOfGasStaticMemoryExpansion, + ExecutionState.ErrorOutOfGasDynamicMemoryExpansion, + ExecutionState.ErrorOutOfGasMemoryCopy, + ExecutionState.ErrorOutOfGasAccountAccess, + ExecutionState.ErrorOutOfGasCodeStore, + ExecutionState.ErrorOutOfGasLOG, + ExecutionState.ErrorOutOfGasEXP, + ExecutionState.ErrorOutOfGasSHA3, + ExecutionState.ErrorOutOfGasEXTCODECOPY, + ExecutionState.ErrorOutOfGasSLOAD, + ExecutionState.ErrorOutOfGasSSTORE, + ExecutionState.ErrorOutOfGasCALL, + ExecutionState.ErrorOutOfGasCALLCODE, + ExecutionState.ErrorOutOfGasDELEGATECALL, + ExecutionState.ErrorOutOfGasCREATE2, + ExecutionState.ErrorOutOfGasSTATICCALL, + ExecutionState.ErrorOutOfGasSELFDESTRUCT, + ] diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index d37b1c4a8..9ef1511b0 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -16,6 +16,7 @@ MEMORY_EXPANSION_QUAD_DENOMINATOR, MEMORY_EXPANSION_LINEAR_COEFF, ) +from .execution_state import ExecutionState from .opcode import Opcode from .step import StepState from .table import ( @@ -107,7 +108,26 @@ def constrain_bool(self, num: FQ): def constrain_gas_left_not_underflow(self, gas_left: FQ): self.range_check(gas_left, N_BYTES_GAS) - def constrain_step_state_transition(self, **kwargs: Mapping[str, Transition]): + def constrain_execution_state_transition(self): + curr, next = self.curr.execution_state, self.next.execution_state + + # ExecutionState transition constraint for special ones + if curr == ExecutionState.EndTx: + assert next in [ExecutionState.BeginTx, ExecutionState.EndBlock] + elif curr == ExecutionState.EndBlock: + assert next == ExecutionState.EndBlock + + # Negation ExecutionState transition constraint for rest ones + if next == ExecutionState.BeginTx: + assert curr == ExecutionState.EndTx + elif next == ExecutionState.EndTx: + assert curr.halts() or curr == ExecutionState.BeginTx + elif next == ExecutionState.EndBlock: + assert curr in [ExecutionState.EndTx, ExecutionState.EndBlock] + elif next == ExecutionState.CopyToMemory: + assert curr in [ExecutionState.CopyToMemory, ExecutionState.CALLDATACOPY] + + def constrain_step_state_transition(self, **kwargs: Transition): keys = set( [ "rw_counter", diff --git a/src/zkevm_specs/evm/main.py b/src/zkevm_specs/evm/main.py index f990f9100..8a9e6f4c7 100644 --- a/src/zkevm_specs/evm/main.py +++ b/src/zkevm_specs/evm/main.py @@ -15,24 +15,26 @@ def verify_steps( begin_with_first_step: bool = False, end_with_last_step: bool = False, ): - # TODO: Enforce general ExecutionState transition constraint + # For the last step, the next step is meaningless + if end_with_last_step: + steps += [None] + + for idx in range(len(steps) - 1): + curr, next = steps[idx], steps[idx + 1] - for idx in range(len(steps) if end_with_last_step else len(steps) - 1): verify_step( Instruction( randomness=randomness, tables=tables, - curr=steps[idx], - next=steps[idx + 1] if idx + 1 < len(steps) else None, + curr=curr, + next=next, is_first_step=begin_with_first_step and idx == 0, - is_last_step=idx + 1 == len(steps), - ), + is_last_step=end_with_last_step and next is None, + ) ) -def verify_step( - instruction: Instruction, -): +def verify_step(instruction: Instruction): if instruction.is_first_step: instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.BeginTx) @@ -43,3 +45,5 @@ def verify_step( if instruction.is_last_step: instruction.constrain_equal(instruction.curr.execution_state, ExecutionState.EndBlock) + else: + instruction.constrain_execution_state_transition() diff --git a/src/zkevm_specs/evm/opcode.py b/src/zkevm_specs/evm/opcode.py index 8f58cdc7e..6b7641aa7 100644 --- a/src/zkevm_specs/evm/opcode.py +++ b/src/zkevm_specs/evm/opcode.py @@ -1,5 +1,6 @@ from enum import IntEnum -from typing import Final, Dict, Sequence, Tuple, Union +from typing import Final, Dict, List, Tuple + from ..util.param import * @@ -184,11 +185,6 @@ class OpcodeInfo: max_stack_pointer: int constant_gas_cost: int has_dynamic_gas: bool - pure_memory_expansion_info: Tuple[ - int, # offset stack_pointer_offset - int, # length stack_pointer_offset - int, # constant length - ] def __init__( self, @@ -196,13 +192,11 @@ def __init__( max_stack_pointer: int, constant_gas_cost: int, has_dynamic_gas: bool = False, - pure_memory_expansion_info: Union[Tuple[int, int, int], None] = None, ) -> None: self.min_stack_pointer = min_stack_pointer self.max_stack_pointer = max_stack_pointer self.constant_gas_cost = constant_gas_cost self.has_dynamic_gas = has_dynamic_gas - self.pure_memory_expansion_info = pure_memory_expansion_info OPCODE_INFO_MAP: Final[Dict[Opcode, OpcodeInfo]] = dict( @@ -260,9 +254,9 @@ def __init__( Opcode.SELFBALANCE: OpcodeInfo(1, 1024, GAS_COST_FAST), Opcode.BASEFEE: OpcodeInfo(1, 1024, GAS_COST_QUICK), Opcode.POP: OpcodeInfo(-1, 1023, GAS_COST_QUICK), - Opcode.MLOAD: OpcodeInfo(0, 1023, GAS_COST_FASTEST, True, (0, 0, 32)), - Opcode.MSTORE: OpcodeInfo(-2, 1022, GAS_COST_FASTEST, True, (0, 0, 32)), - Opcode.MSTORE8: OpcodeInfo(-2, 1022, GAS_COST_FASTEST, True, (0, 0, 1)), + Opcode.MLOAD: OpcodeInfo(0, 1023, GAS_COST_FASTEST, True), + Opcode.MSTORE: OpcodeInfo(-2, 1022, GAS_COST_FASTEST, True), + Opcode.MSTORE8: OpcodeInfo(-2, 1022, GAS_COST_FASTEST, True), Opcode.SLOAD: OpcodeInfo(0, 1023, GAS_COST_ZERO, True), Opcode.SSTORE: OpcodeInfo(-2, 1022, GAS_COST_ZERO, True), Opcode.JUMP: OpcodeInfo(-1, 1023, GAS_COST_MID), @@ -340,28 +334,28 @@ def __init__( Opcode.LOG2: OpcodeInfo(-4, 1020, GAS_COST_ZERO, True), Opcode.LOG3: OpcodeInfo(-5, 1019, GAS_COST_ZERO, True), Opcode.LOG4: OpcodeInfo(-6, 1018, GAS_COST_ZERO, True), - Opcode.CREATE: OpcodeInfo(-2, 1021, GAS_COST_CREATE, True, (1, 2, 0)), + Opcode.CREATE: OpcodeInfo(-2, 1021, GAS_COST_CREATE, True), Opcode.CALL: OpcodeInfo(-6, 1017, GAS_COST_WARM_ACCESS, True), Opcode.CALLCODE: OpcodeInfo(-6, 1017, GAS_COST_WARM_ACCESS, True), - Opcode.RETURN: OpcodeInfo(-2, 1022, GAS_COST_ZERO, True, (0, 1, 0)), + Opcode.RETURN: OpcodeInfo(-2, 1022, GAS_COST_ZERO, True), Opcode.DELEGATECALL: OpcodeInfo(-5, 1018, GAS_COST_WARM_ACCESS, True), Opcode.CREATE2: OpcodeInfo(-3, 1020, GAS_COST_CREATE2, True), Opcode.STATICCALL: OpcodeInfo(-5, 1018, GAS_COST_WARM_ACCESS, True), - Opcode.REVERT: OpcodeInfo(-2, 1022, GAS_COST_ZERO, True, (0, 1, 0)), + Opcode.REVERT: OpcodeInfo(-2, 1022, GAS_COST_ZERO, True), Opcode.SELFDESTRUCT: OpcodeInfo(-1, 1023, GAS_COST_SELF_DESTRUCT, True), } ) -def valid_opcodes() -> Sequence[Opcode]: +def valid_opcodes() -> List[Opcode]: return list(Opcode) -def invalid_opcodes() -> Sequence[int]: +def invalid_opcodes() -> List[int]: return [opcode for opcode in range(256) if opcode not in valid_opcodes()] -def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: +def stack_overflow_pairs() -> List[Tuple[Opcode, int]]: pairs = [] for opcode in valid_opcodes(): if opcode.min_stack_pointer() > 0: @@ -370,7 +364,7 @@ def stack_overflow_pairs() -> Sequence[Tuple[int, int]]: return pairs -def stack_underflow_pairs() -> Sequence[Tuple[int, int]]: +def stack_underflow_pairs() -> List[Tuple[Opcode, int]]: pairs = [] for opcode in valid_opcodes(): if opcode.max_stack_pointer() < 1024: @@ -379,7 +373,7 @@ def stack_underflow_pairs() -> Sequence[Tuple[int, int]]: return pairs -def opcode_constant_gas_cost_pairs() -> Sequence[Tuple[int, int]]: +def constant_gas_cost_pairs() -> List[Tuple[Opcode, int]]: pairs = [] for opcode in valid_opcodes(): if not opcode.has_dynamic_gas() and opcode.constant_gas_cost() > 0: @@ -387,7 +381,7 @@ def opcode_constant_gas_cost_pairs() -> Sequence[Tuple[int, int]]: return pairs -def state_write_opcodes() -> Sequence[int]: +def state_write_opcodes() -> List[Opcode]: return [ Opcode.SSTORE, Opcode.LOG0, @@ -402,19 +396,19 @@ def state_write_opcodes() -> Sequence[int]: ] -def call_opcodes() -> Sequence[int]: +def call_opcodes() -> List[Opcode]: return [Opcode.CALL, Opcode.CALLCODE, Opcode.DELEGATECALL, Opcode.STATICCALL] -def ether_transfer_opcdes() -> Sequence[int]: +def ether_transfer_opcdes() -> List[Opcode]: return [Opcode.CALL, Opcode.CALLCODE] -def create_opcodes() -> Sequence[int]: +def create_opcodes() -> List[Opcode]: return [Opcode.CREATE, Opcode.CREATE2] -def jump_opcodes() -> Sequence[int]: +def jump_opcodes() -> List[Opcode]: return [Opcode.JUMP, Opcode.JUMPI] diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 8182a4ccc..cd7e42867 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -5,12 +5,6 @@ from ..util import FQ, RLC, Array3, Array4, Array10 from .execution_state import ExecutionState -from .opcode import ( - invalid_opcodes, - state_write_opcodes, - stack_underflow_pairs, - stack_overflow_pairs, -) class Placeholder: @@ -34,11 +28,7 @@ class FixedTableTag(IntEnum): BitwiseAnd = auto() # lhs, rhs, lhs & rhs, 0 BitwiseOr = auto() # lhs, rhs, lhs | rhs, 0 BitwiseXor = auto() # lhs, rhs, lhs ^ rhs, 0 - ResponsibleOpcode = auto() # execution_state, opcode, 0 - InvalidOpcode = auto() # opcode, 0, 0 - StateWriteOpcode = auto() # opcode, 0, 0 - StackOverflow = auto() # opcode, stack_pointer, 0 - StackUnderflow = auto() # opcode, stack_pointer, 0 + ResponsibleOpcode = auto() # execution_state, opcode, aux def table_assignments(self) -> Sequence[Array4]: if self == FixedTableTag.Range16: @@ -63,25 +53,15 @@ def table_assignments(self) -> Sequence[Array4]: return [(self, lhs, rhs, lhs ^ rhs) for lhs, rhs in product(range(256), range(256))] elif self == FixedTableTag.ResponsibleOpcode: return [ - (self, execution_state, opcode, 0) + (self, execution_state, opcode, aux) for execution_state in list(ExecutionState) - for opcode in execution_state.responsible_opcode() - ] - elif self == FixedTableTag.InvalidOpcode: - return [(self, opcode, 0, 0) for opcode in invalid_opcodes()] - elif self == FixedTableTag.StateWriteOpcode: - return [(self, opcode, 0, 0) for opcode in state_write_opcodes()] - elif self == FixedTableTag.StackOverflow: - return [ - (self, opcode, stack_pointer, 0) - for opcode, stack_pointer in stack_underflow_pairs() - ] - elif self == FixedTableTag.StackUnderflow: - return [ - (self, opcode, stack_pointer, 0) for opcode, stack_pointer in stack_overflow_pairs() + for opcode, aux in map( + lambda pair: pair if isinstance(pair, tuple) else (pair, 0), + execution_state.responsible_opcode(), + ) ] else: - ValueError("Unreacheable") + raise ValueError("Unreacheable") def range_table_tag(range: int) -> FixedTableTag: if range == 16: