From 7ad9d96d554d474d610afeefdb601e4874aae325 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Thu, 26 Jan 2023 18:31:52 +0800 Subject: [PATCH] Handle OOG of callcode, delegatecall and staticcall in `oog_call`. --- specs/error_state/ErrorOutOfGasCall.md | 11 +- src/zkevm_specs/evm/execution/__init__.py | 2 +- src/zkevm_specs/evm/execution/oog_call.py | 20 +- src/zkevm_specs/evm/execution_state.py | 11 +- tests/evm/test_oog_call.py | 289 ++++++++++++---------- 5 files changed, 190 insertions(+), 143 deletions(-) diff --git a/specs/error_state/ErrorOutOfGasCall.md b/specs/error_state/ErrorOutOfGasCall.md index 3c61bed87..7ff5d8da3 100644 --- a/specs/error_state/ErrorOutOfGasCall.md +++ b/specs/error_state/ErrorOutOfGasCall.md @@ -1,15 +1,19 @@ # ErrorOutOfGasCall state ## Procedure + +Handle the corresponding out of gas errors for `CALL`, `CALLCODE`, `DELEGATECALL` and `STATICCALL` opcodes. + ### EVM behavior + For this gadget, the core is to calculate gas required, there are multiple kinds of gas -consumes in call: +consumes in above call related opcodes: 1. memory expansion gas cost 2. gas cost if new account creates 3. transfer fee if has value -4. account access list cost(warm/cold) +4. account access list cost (warm/cold) -below is the total gas cost calculation which from [`Call` Spec](../opcode/F1CALL_F4DELEGATECALL_FASTATICCALL.md). +below is the total gas cost calculation which from [Spec of call related opcodes](../opcode/F1CALL_F4DELEGATECALL_FASTATICCALL.md). ``` GAS_COST_WARM_ACCESS := 100 GAS_COST_ACCOUNT_COLD_ACCESS := 2600 @@ -23,6 +27,7 @@ gas_cost = ( ``` ### Constraints + 1. `gas_left < gas_cost`. 2. Current call must be failed. 3. If it's a root call, it transits to `EndTx`. diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 743b18f79..252eece90 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -108,7 +108,7 @@ ExecutionState.RETURN: return_revert, ExecutionState.ErrorOutOfGasConstant: oog_constant, ExecutionState.ErrorInvalidJump: invalid_jump, - ExecutionState.ErrorOutOfGasCALL: oog_call, + ExecutionState.ErrorOutOfGasCall: oog_call, ExecutionState.ErrorStack: stack_error, # ExecutionState.ECRECOVER: , # ExecutionState.SHA256: , diff --git a/src/zkevm_specs/evm/execution/oog_call.py b/src/zkevm_specs/evm/execution/oog_call.py index 9eb2d18d0..f40ef46af 100644 --- a/src/zkevm_specs/evm/execution/oog_call.py +++ b/src/zkevm_specs/evm/execution/oog_call.py @@ -7,16 +7,22 @@ from ..opcode import Opcode +# Handle the corresponding out of gas errors for CALL, CALLCODE, DELEGATECALL +# and STATICCALL opcodes. def oog_call(instruction: Instruction): # retrieve op code associated to oog call error opcode = instruction.opcode_lookup(True) - # TODO: add CallCode etc.when handle ErrorOutOfGasCALLCODE in future implementation - instruction.constrain_equal(opcode, Opcode.CALL) + is_call, is_callcode, is_delegatecall, is_staticcall = instruction.multiple_select( + opcode, (Opcode.CALL, Opcode.CALLCODE, Opcode.DELEGATECALL, Opcode.STATICCALL) + ) + + # Constrain opcode must be CALL, CALLCODE, DELEGATECALL or STATICCALL. + instruction.constrain_equal(is_call + is_callcode + is_delegatecall + is_staticcall, FQ(1)) tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId) # init CallGadget to handle stack vars. - call = CallGadget(instruction, FQ(0), FQ(1), FQ(0), FQ(0)) + call = CallGadget(instruction, FQ(0), is_call, is_callcode, is_delegatecall) # TODO: handle PrecompiledContract oog cases @@ -39,18 +45,22 @@ def oog_call(instruction: Instruction): is_to_end_tx = instruction.is_equal(instruction.next.execution_state, ExecutionState.EndTx) instruction.constrain_equal(FQ(instruction.curr.is_root), is_to_end_tx) + # Both CALL and CALLCODE opcodes have an extra stack pop `value` relative to + # DELEGATECALL and STATICCALL. + rw_counter_delta = 11 + is_call + is_callcode + # state transition. if instruction.curr.is_root: # Do step state transition instruction.constrain_step_state_transition( - rw_counter=Transition.delta(12), + rw_counter=Transition.delta(rw_counter_delta), call_id=Transition.same(), ) else: # when it is internal call, need to restore caller's state as finishing this call. # Restore caller state to next StepState instruction.step_state_transition_to_restored_context( - rw_counter_delta=12, + rw_counter_delta=rw_counter_delta.n, return_data_offset=FQ(0), return_data_length=FQ(0), gas_left=instruction.curr.gas_left, diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index d23ce1a93..2565dba28 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -118,11 +118,9 @@ class ExecutionState(IntEnum): ErrorOutOfGasEXTCODECOPY = auto() ErrorOutOfGasSLOAD = auto() ErrorOutOfGasSSTORE = auto() - ErrorOutOfGasCALL = auto() - ErrorOutOfGasCALLCODE = auto() - ErrorOutOfGasDELEGATECALL = auto() + # For CALL, CALLCODE, DELEGATECALL and STATICCALL opcodes which may run out of gas. + ErrorOutOfGasCall = auto() ErrorOutOfGasCREATE2 = auto() - ErrorOutOfGasSTATICCALL = auto() ErrorOutOfGasSELFDESTRUCT = auto() # Precompile's successful cases @@ -393,11 +391,8 @@ def halts_in_exception(self) -> bool: ExecutionState.ErrorOutOfGasEXTCODECOPY, ExecutionState.ErrorOutOfGasSLOAD, ExecutionState.ErrorOutOfGasSSTORE, - ExecutionState.ErrorOutOfGasCALL, - ExecutionState.ErrorOutOfGasCALLCODE, - ExecutionState.ErrorOutOfGasDELEGATECALL, + ExecutionState.ErrorOutOfGasCall, ExecutionState.ErrorOutOfGasCREATE2, - ExecutionState.ErrorOutOfGasSTATICCALL, ExecutionState.ErrorOutOfGasSELFDESTRUCT, ] diff --git a/tests/evm/test_oog_call.py b/tests/evm/test_oog_call.py index 063a0295e..1ff97ccf9 100644 --- a/tests/evm/test_oog_call.py +++ b/tests/evm/test_oog_call.py @@ -1,76 +1,128 @@ -import itertools import pytest from collections import namedtuple -from itertools import chain - +from itertools import chain, product from zkevm_specs.evm import ( - ExecutionState, - StepState, - verify_steps, - Tables, + Account, AccountFieldTag, - CallContextFieldTag, Block, - Account, Bytecode, + CallContextFieldTag, + ExecutionState, + Opcode, RWDictionary, + StepState, + Tables, + verify_steps, ) -from zkevm_specs.util import rand_fq, RLC, EMPTY_CODE_HASH -from zkevm_specs.util.param import ( - GAS_COST_NEW_ACCOUNT, +from zkevm_specs.util import ( + EMPTY_CODE_HASH, + GAS_COST_ACCOUNT_COLD_ACCESS, GAS_COST_CALL_WITH_VALUE, + GAS_COST_NEW_ACCOUNT, GAS_COST_WARM_ACCESS, - GAS_COST_ACCOUNT_COLD_ACCESS, GAS_STIPEND_CALL_WITH_VALUE, + RLC, + rand_fq, ) - CallContext = namedtuple( "CallContext", [ - "rw_counter_end_of_reversion", - "is_persistent", "gas_left", "memory_size", "reversible_write_counter", ], - defaults=[0, False, 0, 0, 2], + defaults=[0, 0, 2], ) + Stack = namedtuple( "Stack", ["gas", "value", "cd_offset", "cd_length", "rd_offset", "rd_length"], - defaults=[0, 0, 0, 0, 0, 0], + defaults=[100, 0, 64, 320, 0, 32], ) -STOP_BYTECODE = Bytecode().stop() -CALLER = Account(address=0xFE, balance=int(1e20)) -CALLEE_WITH_STOP_BYTECODE_AND_BALANCE = Account(address=0xFF, code=STOP_BYTECODE, balance=int(1e18)) +def call_bytecode(opcode: Opcode, stack: Stack, callee: Account) -> Bytecode: + if opcode == Opcode.CALL: + bytecode = ( + Bytecode() + .call( + stack.gas, + callee.address, + stack.value, + stack.cd_offset, + stack.cd_length, + stack.rd_offset, + stack.rd_length, + ) + .stop() + ) + elif opcode == Opcode.CALLCODE: + bytecode = ( + Bytecode() + .callcode( + stack.gas, + callee.address, + stack.value, + stack.cd_offset, + stack.cd_length, + stack.rd_offset, + stack.rd_length, + ) + .stop() + ) + elif opcode == Opcode.DELEGATECALL: + bytecode = ( + Bytecode() + .delegatecall( + stack.gas, + callee.address, + stack.cd_offset, + stack.cd_length, + stack.rd_offset, + stack.rd_length, + ) + .stop() + ) + elif opcode == Opcode.STATICCALL: + bytecode = ( + Bytecode() + .staticcall( + stack.gas, + callee.address, + stack.cd_offset, + stack.cd_length, + stack.rd_offset, + stack.rd_length, + ) + .stop() + ) + else: + raise Exception("unreachable") + + return bytecode def gen_testing_data(): - callees = [ - CALLEE_WITH_STOP_BYTECODE_AND_BALANCE, - ] - call_contexts = [ - CallContext(gas_left=50, is_persistent=False), - CallContext(gas_left=100, is_persistent=False, rw_counter_end_of_reversion=0), - ] + callee = Account(address=0xFF, code=Bytecode().stop(), balance=int(1e18)) + call_opcodes = [Opcode.CALL, Opcode.CALLCODE, Opcode.DELEGATECALL, Opcode.STATICCALL] + call_contexts = [CallContext(gas_left=50), CallContext(gas_left=100)] stacks = [ Stack(gas=100, cd_offset=64, cd_length=320, rd_offset=0, rd_length=32), ] - is_warm_accesss = [True, False] + is_warm_accesses = [True, False] return [ ( - CALLER, callee, - call_context, + call_bytecode(opcode, stack, callee), + caller_context, stack, + opcode in [Opcode.CALL, Opcode.CALLCODE], is_warm_access, ) - for callee, call_context, stack, is_warm_access in itertools.product( - callees, call_contexts, stacks, is_warm_accesss + for opcode, caller_context, stack, is_warm_access in product( + call_opcodes, call_contexts, stacks, is_warm_accesses ) ] @@ -78,50 +130,43 @@ def gen_testing_data(): TESTING_DATA = gen_testing_data() -@pytest.mark.parametrize("caller, callee, caller_ctx, stack, is_warm_access", TESTING_DATA) -def test_root_call( - caller: Account, +@pytest.mark.parametrize( + "callee, caller_bytecode, caller_context, stack, has_value, is_warm_access", TESTING_DATA +) +def test_oog_call_root( callee: Account, - caller_ctx: CallContext, + caller_bytecode: Bytecode, + caller_context: CallContext, stack: Stack, + has_value: bool, is_warm_access: bool, ): randomness = rand_fq() - caller_bytecode = ( - Bytecode() - .call( - stack.gas, - callee.address, - stack.value, - stack.cd_offset, - stack.cd_length, - stack.rd_offset, - stack.rd_length, - ) - .stop() - ) caller_bytecode_hash = RLC(caller_bytecode.hash(), randomness) callee_bytecode_hash = RLC(callee.code_hash(), randomness) is_success = False + program_counter = 231 if has_value else 198 - # fmt: off rw_dictionary = ( RWDictionary(24) .call_context_read(1, CallContextFieldTag.TxId, 1) - .stack_read(1, 1017, RLC(stack.gas, randomness)) - .stack_read(1, 1018, RLC(callee.address, randomness)) - .stack_read(1, 1019, RLC(stack.value, randomness)) - .stack_read(1, 1020, RLC(stack.cd_offset, randomness)) - .stack_read(1, 1021, RLC(stack.cd_length, randomness)) - .stack_read(1, 1022, RLC(stack.rd_offset, randomness)) - .stack_read(1, 1023, RLC(stack.rd_length, randomness)) - .stack_write(1, 1023, RLC(is_success, randomness)) - .account_read(callee.address, AccountFieldTag.CodeHash, callee_bytecode_hash) - .tx_access_list_account_read(1, callee.address, is_warm_access) - .call_context_read(1, CallContextFieldTag.IsSuccess, 0) + .stack_read(1, 1018 - has_value, RLC(stack.gas, randomness)) + .stack_read(1, 1019 - has_value, RLC(callee.address, randomness)) ) - # fmt: on + if has_value: + rw_dictionary.stack_read(1, 1019, RLC(stack.value, randomness)) + # fmt: off + rw_dictionary \ + .stack_read(1, 1020, RLC(stack.cd_offset, randomness)) \ + .stack_read(1, 1021, RLC(stack.cd_length, randomness)) \ + .stack_read(1, 1022, RLC(stack.rd_offset, randomness)) \ + .stack_read(1, 1023, RLC(stack.rd_length, randomness)) \ + .stack_write(1, 1023, RLC(is_success, randomness)) \ + .account_read(callee.address, AccountFieldTag.CodeHash, callee_bytecode_hash) \ + .tx_access_list_account_read(1, callee.address, is_warm_access) \ + .call_context_read(1, CallContextFieldTag.IsSuccess, 0) + # fmt on tables = Tables( block_table=set(Block().table_assignments(randomness)), @@ -140,17 +185,17 @@ def test_root_call( tables=tables, steps=[ StepState( - execution_state=ExecutionState.ErrorOutOfGasCALL, + execution_state=ExecutionState.ErrorOutOfGasCall, rw_counter=24, call_id=1, is_root=True, is_create=False, code_hash=caller_bytecode_hash, - program_counter=231, - stack_pointer=1017, - gas_left=caller_ctx.gas_left, - memory_size=caller_ctx.memory_size, - reversible_write_counter=caller_ctx.reversible_write_counter, + program_counter=program_counter, + stack_pointer=1018 - has_value, + gas_left=caller_context.gas_left, + memory_size=caller_context.memory_size, + reversible_write_counter=caller_context.reversible_write_counter, ), StepState( execution_state=ExecutionState.EndTx, @@ -162,65 +207,57 @@ def test_root_call( ) -CallerContext = namedtuple( - "CallerContext", - [ - "is_root", - "is_create", - "program_counter", - "stack_pointer", - "gas_left", - "memory_size", - "reversible_write_counter", - ], - defaults=[False, False, 232, 1023, 10, 0, 0], +@pytest.mark.parametrize( + "callee, caller_bytecode, caller_context, stack, has_value, is_warm_access", TESTING_DATA ) - -TESTING_DATA_NOT_ROOT = ((CallerContext(), CALLEE_WITH_STOP_BYTECODE_AND_BALANCE),) - - -@pytest.mark.parametrize("caller_ctx, callee", TESTING_DATA_NOT_ROOT) -def test_oog_call_not_root(caller_ctx: CallerContext, callee: Account): +def test_oog_call_not_root( + callee: Account, + caller_bytecode: Bytecode, + caller_context: CallContext, + stack: Stack, + has_value: bool, + is_warm_access: bool, +): randomness = rand_fq() - caller_bytecode = Bytecode().call(0, 0xFF, 0, 0, 0, 0, 0).stop() caller_bytecode_hash = RLC(caller_bytecode.hash(), randomness) callee_bytecode_hash = RLC(callee.code_hash(), randomness) callee_reversible_write_counter = 0 - stack = Stack(gas=100, cd_offset=64, cd_length=320, rd_offset=0, rd_length=32) + is_success = False + program_counter = 231 if has_value else 198 - is_warm_access = False rw_dictionary = ( RWDictionary(24) .call_context_read(2, CallContextFieldTag.TxId, 1) - .stack_read(2, 1017, RLC(stack.gas, randomness)) - .stack_read(2, 1018, RLC(callee.address, randomness)) - .stack_read(2, 1019, RLC(stack.value, randomness)) - .stack_read(2, 1020, RLC(stack.cd_offset, randomness)) - .stack_read(2, 1021, RLC(stack.cd_length, randomness)) - .stack_read(2, 1022, RLC(stack.rd_offset, randomness)) - .stack_read(2, 1023, RLC(stack.rd_length, randomness)) - .stack_write(2, 1023, RLC(False, randomness)) - .account_read(callee.address, AccountFieldTag.CodeHash, callee_bytecode_hash) - .tx_access_list_account_read(1, callee.address, is_warm_access) - .call_context_read(2, CallContextFieldTag.IsSuccess, 0) - # restore context operations - .call_context_read(2, CallContextFieldTag.CallerId, 1) - .call_context_read(1, CallContextFieldTag.IsRoot, caller_ctx.is_root) - .call_context_read(1, CallContextFieldTag.IsCreate, caller_ctx.is_create) - .call_context_read(1, CallContextFieldTag.CodeHash, caller_bytecode_hash) - .call_context_read(1, CallContextFieldTag.ProgramCounter, caller_ctx.program_counter) - .call_context_read(1, CallContextFieldTag.StackPointer, caller_ctx.stack_pointer) - .call_context_read(1, CallContextFieldTag.GasLeft, caller_ctx.gas_left) - .call_context_read(1, CallContextFieldTag.MemorySize, caller_ctx.memory_size) - .call_context_read( - 1, CallContextFieldTag.ReversibleWriteCounter, caller_ctx.reversible_write_counter - ) - .call_context_write(1, CallContextFieldTag.LastCalleeId, 2) - .call_context_write(1, CallContextFieldTag.LastCalleeReturnDataOffset, 0) - .call_context_write(1, CallContextFieldTag.LastCalleeReturnDataLength, 0) + .stack_read(2, 1018 - has_value, RLC(stack.gas, randomness)) + .stack_read(2, 1019 - has_value, RLC(callee.address, randomness)) ) + if has_value: + rw_dictionary.stack_read(2, 1019, RLC(stack.value, randomness)) + # fmt: off + rw_dictionary \ + .stack_read(2, 1020, RLC(stack.cd_offset, randomness)) \ + .stack_read(2, 1021, RLC(stack.cd_length, randomness)) \ + .stack_read(2, 1022, RLC(stack.rd_offset, randomness)) \ + .stack_read(2, 1023, RLC(stack.rd_length, randomness)) \ + .stack_write(2, 1023, RLC(is_success, randomness)) \ + .account_read(callee.address, AccountFieldTag.CodeHash, callee_bytecode_hash) \ + .tx_access_list_account_read(1, callee.address, is_warm_access) \ + .call_context_read(2, CallContextFieldTag.IsSuccess, 0) \ + .call_context_read(2, CallContextFieldTag.CallerId, 1) \ + .call_context_read(1, CallContextFieldTag.IsRoot, False) \ + .call_context_read(1, CallContextFieldTag.IsCreate, False) \ + .call_context_read(1, CallContextFieldTag.CodeHash, caller_bytecode_hash) \ + .call_context_read(1, CallContextFieldTag.ProgramCounter, program_counter + 1) \ + .call_context_read(1, CallContextFieldTag.StackPointer, 1023) \ + .call_context_read(1, CallContextFieldTag.GasLeft, caller_context.gas_left) \ + .call_context_read(1, CallContextFieldTag.MemorySize, caller_context.memory_size) \ + .call_context_read(1, CallContextFieldTag.ReversibleWriteCounter, caller_context.reversible_write_counter) \ + .call_context_write(1, CallContextFieldTag.LastCalleeId, 2) \ + .call_context_write(1, CallContextFieldTag.LastCalleeReturnDataOffset, 0) \ + .call_context_write(1, CallContextFieldTag.LastCalleeReturnDataLength, 0) \ + # fmt on tables = Tables( block_table=set(Block().table_assignments(randomness)), @@ -239,14 +276,14 @@ def test_oog_call_not_root(caller_ctx: CallerContext, callee: Account): tables=tables, steps=[ StepState( - execution_state=ExecutionState.ErrorOutOfGasCALL, + execution_state=ExecutionState.ErrorOutOfGasCall, rw_counter=24, call_id=2, is_root=False, is_create=False, code_hash=caller_bytecode_hash, - program_counter=231, - stack_pointer=1017, + program_counter=program_counter, + stack_pointer=1018 - has_value, gas_left=0, reversible_write_counter=callee_reversible_write_counter, ), @@ -254,14 +291,14 @@ def test_oog_call_not_root(caller_ctx: CallerContext, callee: Account): execution_state=ExecutionState.STOP, rw_counter=rw_dictionary.rw_counter, call_id=1, - is_root=caller_ctx.is_root, - is_create=caller_ctx.is_create, + is_root=False, + is_create=False, code_hash=caller_bytecode_hash, - program_counter=caller_ctx.program_counter, - stack_pointer=caller_ctx.stack_pointer, - gas_left=caller_ctx.gas_left, - memory_size=caller_ctx.memory_size, - reversible_write_counter=caller_ctx.reversible_write_counter + program_counter=program_counter + 1, + stack_pointer=1023, + gas_left=caller_context.gas_left, + memory_size=caller_context.memory_size, + reversible_write_counter=caller_context.reversible_write_counter + callee_reversible_write_counter, ), ],