diff --git a/specs/opcode/39CODECOPY.md b/specs/opcode/39CODECOPY.md new file mode 100644 index 000000000..8cac964d8 --- /dev/null +++ b/specs/opcode/39CODECOPY.md @@ -0,0 +1,43 @@ +# CODECOPY opcode + +## Procedure + +The `CODECOPY` opcode pops `memory_offset`, `code_offset` and `size` from the stack. +It then copies `size` bytes of code running in the current environment from an offset `code_offset` to the memory at the address `memory_offset`. For out-of-bound scenarios where `size > len(code) - code_offset`, EVM pads 0 to the end of the copied bytes. + +The gas cost of `CODECOPY` opcode consists of two parts: + +1. A constant gas cost: `3 gas` +2. A dynamic gas cost: cost of memory expansion and copying (variable depending on the `size` copied to memory) + +## Circuit Behaviour + +`CODECOPY` makes use of the internal execution step `CopyCodeToMemory` and loops over these steps iteratively until there are no more bytes to be copied. The `CODECOPY` circuit itself only constrains the values popped from the stack and the `code_size` looked up from the bytecode table. + +The gadget then transitions to the internal state of `CopyCodeToMemory`. + +## Constraints + +1. opId = 0x39 +2. State Transitions: + - rw_counter -> rw_counter + 3 (3 stack reads) + - stack_pointer -> stack_pointer + 3 + - pc -> pc + 1 + - gas -> 3 + dynamic_cost (memory expansion and copier cost when `size > 0`) + - memory_size + - `prev_memory_size` if `size = 0` + - `max(prev_memory_size, (memory_offset + size + 31) / 32)` if `size > 0` +3. Lookups: + - `memory_offset` is at the top of the stack + - `code_offset` is at the second position of the stack + - `size` is at the third position of the stack + - `code_size` from the bytecode table + +## Exceptions + +1. Stack Underflow: `1022 <= stack_pointer <= 1024` +2. 
Out-of-Gas: remaining gas is not enough + +## Code + +Please refer to `src/zkevm_specs/evm/execution/codecopy.py`. diff --git a/specs/opcode/CopyCodeToMemory.md b/specs/opcode/CopyCodeToMemory.md new file mode 100644 index 000000000..c6b67abde --- /dev/null +++ b/specs/opcode/CopyCodeToMemory.md @@ -0,0 +1,31 @@ +# CopyCodeToMemory + +## Circuit Behaviour + +`CopyCodeToMemory` is an internal execution state and doesn't correspond to an EVM opcode. It verifies that data from the bytecode table has been written to memory. In one iteration this gadget can copy at most `MAX_N_BYTES_COPY_CODE_TO_MEMORY` bytes, hence for lengths longer than this bound the gadget loops itself until there are no more bytes to be copied. + +The `CopyCodeToMemory` circuit uses the `BufferReaderGadget` to check if the access is out of bounds and needs 0 padding. + +The `CopyCodeToMemory` circuit looks up the bytes read from the buffer against both the bytecode table and the RW table (memory-write). An additional constraint checks whether or not the copying is finished, and if not, it constrains the next execution state to continue being `CopyCodeToMemory` while also adding constraints to the next step's auxiliary data. + +## Constraints + +We define `n_bytes_read` as the number of bytes read from the bytecode table. `n_bytes_read <= MAX_N_BYTES_COPY_CODE_TO_MEMORY`. + +We define `n_bytes_written` as the number of bytes written to the memory. `n_bytes_written <= MAX_N_BYTES_COPY_CODE_TO_MEMORY`. + +`n_bytes_read` differs from `n_bytes_written` in out-of-bound cases where nothing is read from the bytecode table but a `0` is written to memory. + +1. State Transition: + - rw_counter -> rw_counter + `n_bytes_written` +2. Lookups: + - `n_bytes_read` lookups from bytecode table + - `n_bytes_written` lookups from RW table (memory-write) + +## Exceptions + +No exceptions for `CopyCodeToMemory` since it is an internal state. + +## Code + +Please refer to `src/zkevm_specs/evm/execution/copy_code_to_memory.py`. 
diff --git a/src/zkevm_specs/bytecode.py b/src/zkevm_specs/bytecode.py index 1fee53213..83ced381d 100644 --- a/src/zkevm_specs/bytecode.py +++ b/src/zkevm_specs/bytecode.py @@ -1,13 +1,13 @@ from typing import Sequence, Union, Tuple, Set, NamedTuple from collections import namedtuple -from .util import keccak256, FQ, RLC -from .evm import get_push_size, BytecodeTableRow +from .util import keccak256, EMPTY_HASH, FQ, RLC +from .evm import get_push_size, BytecodeFieldTag, BytecodeTableRow from .encoding import U8, U256, is_circuit_code # Row in the circuit Row = namedtuple( "Row", - "q_first q_last hash index byte is_code push_data_left hash_rlc hash_length byte_push_size is_final padding", + "q_first q_last hash tag index is_code value push_data_left hash_rlc hash_length byte_push_size is_final padding", ) # Unrolled bytecode class UnrolledBytecode(NamedTuple): @@ -33,21 +33,28 @@ def select( def check_bytecode_row( row: Row, prev_row: Row, + next_row: Row, push_table: Set[Tuple[int, int]], keccak_table: Set[Tuple[int, int, int]], r: int, ): row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in row]) prev_row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in prev_row]) + next_row = Row(*[v if isinstance(v, RLC) else FQ(v) for v in next_row]) if row.q_first == 0 and prev_row.is_final == 0: # Continue - # index needs to increase by 1 - assert row.index == prev_row.index + 1 - # is_code := push_data_left_prev == 0 - assert row.is_code == (prev_row.push_data_left == 0) + if prev_row.tag == BytecodeFieldTag.Length: + # index starts from 0 + assert row.index == 0 + # is_code := 1, since this is the first byte of the bytecode + assert row.is_code == 1 + else: + # index is 1 more than previous row's index + assert row.index == prev_row.index + 1 + # is_code := push_data_left_prev == 0 + assert row.is_code == (prev_row.push_data_left == 0) # hash_rlc := hash_rlc_prev * r + byte - assert row.hash_rlc == prev_row.hash_rlc * r + row.byte - + assert row.hash_rlc == 
prev_row.hash_rlc * r + row.value # padding needs to remain the same assert row.padding == prev_row.padding # hash needs to remain the same @@ -56,21 +63,33 @@ def check_bytecode_row( assert row.hash_length == prev_row.hash_length else: # Start - # index needs to start at 0 - assert row.index == 0 - # is_code needs to be 1 (first byte is always an opcode) - assert row.is_code == True - # hash_rlc needs to start at byte - assert row.hash_rlc == row.byte + # the row following an `is_final` previous row is either tagged Length + if row.tag == BytecodeFieldTag.Length: + # value matches hash length + assert row.value == row.hash_length + # if bytecode length is zero + if row.value == 0: + # bytecode hash should be EMPTY_HASH + assert row.hash == RLC(EMPTY_HASH, FQ(r)).expr() + # the next row should be a tag Length or padding + assert (next_row.tag == BytecodeFieldTag.Length) or (next_row.tag == 0) + else: + # the next row should be tag Byte + assert next_row.tag == BytecodeFieldTag.Byte + # or is the start of padding rows + else: + assert row.padding == 1 # is_final needs to be boolean assert_bool(row.is_final) # padding needs to be boolean assert_bool(row.padding) - # push_data_left := is_code ? byte_push_size : push_data_left_prev - 1 - assert row.push_data_left == select( - row.is_code, row.byte_push_size, prev_row.push_data_left - 1 - ) + + if row.tag == BytecodeFieldTag.Byte: + # push_data_left := is_code ? 
byte_push_size : push_data_left_prev - 1 + assert row.push_data_left == select( + row.is_code, row.byte_push_size, prev_row.push_data_left - 1 + ) # Padding if row.q_first == 0: @@ -87,9 +106,10 @@ def check_bytecode_row( # the last row needs to be the last byte assert row.padding == 1 or row.is_final == 1 - # Lookup how many bytes the current opcode pushes - # (also indirectly range checks `byte` to be in [0, 255]) - assert (row.byte, row.byte_push_size) in push_table + if row.tag == BytecodeFieldTag.Byte: + # Lookup how many bytes the current opcode pushes + # (also indirectly range checks `byte` to be in [0, 255]) + assert (row.value, row.byte_push_size) in push_table # keccak lookup when on the last byte if row.is_final == 1 and row.padding == 0: @@ -107,13 +127,15 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando push_data_left = 0 hash_rlc = FQ(0) for idx, row in enumerate(bytecode.rows): + # Subsequent rows represent the bytecode bytes # Track which byte is an opcode and which is push data is_code = push_data_left == 0 - byte_push_size = get_push_size(row.byte) - push_data_left = byte_push_size if is_code else push_data_left - 1 - - # Add the byte to the accumulator - hash_rlc = hash_rlc * randomness + row.byte + byte_push_size = 0 + if idx > 0: + byte_push_size = get_push_size(row.value) + push_data_left = byte_push_size if is_code else push_data_left - 1 + # Add the byte to the accumulator + hash_rlc = hash_rlc * randomness + row.value # Set the data for this row rows.append( @@ -121,14 +143,16 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando offset == 0, offset == last_row_offset, row.bytecode_hash, + row.field_tag, row.index, - row.byte, row.is_code, + row.value, push_data_left, hash_rlc, len(bytecode.bytes), byte_push_size, - idx == len(bytecode.bytes) - 1, + # Since 1 row is taken up by the Length tag + idx == len(bytecode.bytes), False, ) ) @@ -147,6 +171,7 @@ def 
assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rando 0, 0, 0, + 0, True, 0, 0, diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index a9c9fed36..699c949fc 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -3,6 +3,7 @@ from ..execution_state import ExecutionState from .begin_tx import * +from .copy_code_to_memory import * from .end_tx import * from .end_block import * from .memory_copy import * @@ -19,6 +20,7 @@ from .callvalue import * from .calldatacopy import * from .calldataload import * +from .codecopy import * from .gas import * from .iszero import * from .jump import * @@ -39,6 +41,7 @@ ExecutionState.EndTx: end_tx, ExecutionState.EndBlock: end_block, ExecutionState.CopyToMemory: copy_to_memory, + ExecutionState.CopyCodeToMemory: copy_code_to_memory, ExecutionState.ADD: add, ExecutionState.ORIGIN: origin, ExecutionState.CALLER: caller, @@ -46,6 +49,7 @@ ExecutionState.CALLDATACOPY: calldatacopy, ExecutionState.CALLDATALOAD: calldataload, ExecutionState.CALLDATASIZE: calldatasize, + ExecutionState.CODECOPY: codecopy, ExecutionState.COINBASE: coinbase, ExecutionState.TIMESTAMP: timestamp, ExecutionState.NUMBER: number, diff --git a/src/zkevm_specs/evm/execution/calldatacopy.py b/src/zkevm_specs/evm/execution/calldatacopy.py index 925501515..1b66b9c95 100644 --- a/src/zkevm_specs/evm/execution/calldatacopy.py +++ b/src/zkevm_specs/evm/execution/calldatacopy.py @@ -33,7 +33,7 @@ def calldatacopy(instruction: Instruction): gas_cost = instruction.memory_copier_gas_cost(length, memory_expansion_gas_cost) # When length != 0, constrain the state in the next execution state CopyToMemory - if not instruction.is_zero(length): + if instruction.is_zero(length) == FQ(0): assert instruction.next is not None instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) next_aux = instruction.next.aux_data diff --git 
a/src/zkevm_specs/evm/execution/codecopy.py b/src/zkevm_specs/evm/execution/codecopy.py new file mode 100644 index 000000000..bea246a0b --- /dev/null +++ b/src/zkevm_specs/evm/execution/codecopy.py @@ -0,0 +1,47 @@ +from ...util import N_BYTES_MEMORY_ADDRESS, FQ +from ..execution_state import ExecutionState +from ..instruction import Instruction, Transition +from ..step import CopyCodeToMemoryAuxData +from ..table import RW, RWTableTag, CallContextFieldTag, AccountFieldTag + + +def codecopy(instruction: Instruction): + opcode = instruction.opcode_lookup(True) + + memory_offset_word, code_offset_word, size_word = ( + instruction.stack_pop(), + instruction.stack_pop(), + instruction.stack_pop(), + ) + + memory_offset, size = instruction.memory_offset_and_length(memory_offset_word, size_word) + code_offset = instruction.rlc_to_fq_exact(code_offset_word, N_BYTES_MEMORY_ADDRESS) + + code_size = instruction.bytecode_length(instruction.curr.code_source) + + next_memory_size, memory_expansion_gas_cost = instruction.memory_expansion_dynamic_length( + memory_offset, size + ) + gas_cost = instruction.memory_copier_gas_cost(size, memory_expansion_gas_cost) + + if instruction.is_zero(size) == FQ(0): + assert instruction.next is not None + instruction.constrain_equal( + instruction.next.execution_state, ExecutionState.CopyCodeToMemory + ) + next_aux = instruction.next.aux_data + assert isinstance(next_aux, CopyCodeToMemoryAuxData) + instruction.constrain_equal(next_aux.src_addr, code_offset) + instruction.constrain_equal(next_aux.dst_addr, memory_offset) + instruction.constrain_equal(next_aux.src_addr_end, code_size) + instruction.constrain_equal(next_aux.bytes_left, size) + instruction.constrain_equal(next_aux.code_source, instruction.curr.code_source) + + instruction.step_state_transition_in_same_context( + opcode, + rw_counter=Transition.delta(instruction.rw_counter_offset), + program_counter=Transition.delta(1), + stack_pointer=Transition.delta(3), + 
memory_size=Transition.to(next_memory_size), + dynamic_gas_cost=gas_cost, + ) diff --git a/src/zkevm_specs/evm/execution/copy_code_to_memory.py b/src/zkevm_specs/evm/execution/copy_code_to_memory.py new file mode 100644 index 000000000..2cec47410 --- /dev/null +++ b/src/zkevm_specs/evm/execution/copy_code_to_memory.py @@ -0,0 +1,62 @@ +import itertools +from typing import Iterator + +from ...util import FQ, MAX_N_BYTES_COPY_CODE_TO_MEMORY, N_BYTES_MEMORY_SIZE, RLC +from ..execution_state import ExecutionState +from ..instruction import Instruction, Transition +from ..step import CopyCodeToMemoryAuxData +from ..table import RW +from ..util import BufferReaderGadget + + +def copy_code_to_memory(instruction: Instruction): + aux = instruction.curr.aux_data + assert isinstance(aux, CopyCodeToMemoryAuxData) + + buffer_reader = BufferReaderGadget( + instruction, MAX_N_BYTES_COPY_CODE_TO_MEMORY, aux.src_addr, aux.src_addr_end, aux.bytes_left + ) + + for idx in range(MAX_N_BYTES_COPY_CODE_TO_MEMORY): + if buffer_reader.read_flag(idx) == 1: + byte = instruction.bytecode_lookup( + aux.code_source, + aux.src_addr + idx, + ) + buffer_reader.constrain_byte(idx, byte) + + for idx in range(MAX_N_BYTES_COPY_CODE_TO_MEMORY): + if buffer_reader.has_data(idx) == 1: + byte = instruction.memory_lookup(RW.Write, aux.dst_addr + idx) + buffer_reader.constrain_byte(idx, byte) + + copied_bytes = buffer_reader.num_bytes() + lt, finished = instruction.compare(copied_bytes, aux.bytes_left, N_BYTES_MEMORY_SIZE) + + # either copied bytes are less than the bytes left, or copying is finished + instruction.constrain_zero((1 - lt) * (1 - finished)) + + if finished == 0: + assert instruction.next is not None + instruction.constrain_equal( + instruction.next.execution_state, ExecutionState.CopyCodeToMemory + ) + next_aux = instruction.next.aux_data + assert next_aux is not None and isinstance(next_aux, CopyCodeToMemoryAuxData) + instruction.constrain_equal(next_aux.src_addr, aux.src_addr + 
copied_bytes) + instruction.constrain_equal(next_aux.dst_addr, aux.dst_addr + copied_bytes) + instruction.constrain_equal(next_aux.bytes_left + copied_bytes, aux.bytes_left) + instruction.constrain_equal(next_aux.src_addr_end, aux.src_addr_end) + instruction.constrain_equal(next_aux.code_source, aux.code_source) + + instruction.constrain_step_state_transition( + rw_counter=Transition.delta(instruction.rw_counter_offset), + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_source=Transition.same(), + program_counter=Transition.same(), + stack_pointer=Transition.same(), + memory_size=Transition.same(), + state_write_counter=Transition.same(), + ) diff --git a/src/zkevm_specs/evm/execution/memory_copy.py b/src/zkevm_specs/evm/execution/memory_copy.py index 9519145cf..7cac7f618 100644 --- a/src/zkevm_specs/evm/execution/memory_copy.py +++ b/src/zkevm_specs/evm/execution/memory_copy.py @@ -1,10 +1,9 @@ -from ...util import N_BYTES_MEMORY_SIZE, FQ, Expression +from ...util import MAX_N_BYTES_COPY_TO_MEMORY, N_BYTES_MEMORY_SIZE, FQ, Expression from ..execution_state import ExecutionState from ..instruction import Instruction, Transition from ..step import CopyToMemoryAuxData from ..table import RW from ..util import BufferReaderGadget -from ...util import MAX_COPY_BYTES def copy_to_memory(instruction: Instruction): @@ -12,10 +11,10 @@ def copy_to_memory(instruction: Instruction): assert isinstance(aux, CopyToMemoryAuxData) buffer_reader = BufferReaderGadget( - instruction, MAX_COPY_BYTES, aux.src_addr, aux.src_addr_end, aux.bytes_left + instruction, MAX_N_BYTES_COPY_TO_MEMORY, aux.src_addr, aux.src_addr_end, aux.bytes_left ) - for i in range(MAX_COPY_BYTES): + for i in range(MAX_N_BYTES_COPY_TO_MEMORY): if buffer_reader.read_flag(i) == 0: byte: Expression = FQ(0) elif aux.from_tx == 1: @@ -47,4 +46,12 @@ def copy_to_memory(instruction: Instruction): instruction.constrain_step_state_transition( 
rw_counter=Transition.delta(instruction.rw_counter_offset), + call_id=Transition.same(), + is_root=Transition.same(), + is_create=Transition.same(), + code_source=Transition.same(), + program_counter=Transition.same(), + stack_pointer=Transition.same(), + memory_size=Transition.same(), + state_write_counter=Transition.same(), ) diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index 3ba0e4f14..4b4930c03 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -21,6 +21,7 @@ class ExecutionState(IntEnum): EndBlock = auto() CopyToMemory = auto() CopyToLog = auto() + CopyCodeToMemory = auto() # Opcode's successful cases STOP = auto() diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index ea6b21b5b..84b694983 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -23,6 +23,7 @@ from .table import ( AccountFieldTag, BlockContextFieldTag, + BytecodeFieldTag, CallContextFieldTag, FixedTableRow, RWTableRow, @@ -395,9 +396,16 @@ def tx_log_lookup(self, field_tag: TxLogFieldTag, index: int = 0) -> Expression: return value def bytecode_lookup( - self, bytecode_hash: Expression, index: Expression, is_code: bool + self, bytecode_hash: Expression, index: Expression, is_code: Expression = None ) -> Expression: - return self.tables.bytecode_lookup(bytecode_hash, index, FQ(is_code)).byte + return self.tables.bytecode_lookup( + bytecode_hash, FQ(BytecodeFieldTag.Byte), index, is_code + ).value + + def bytecode_length(self, bytecode_hash: Expression) -> Expression: + return self.tables.bytecode_lookup( + bytecode_hash, FQ(BytecodeFieldTag.Length), FQ(0), FQ(0) + ).value def tx_gas_price(self, tx_id: Expression) -> RLC: return cast_expr(self.tx_context_lookup(tx_id, TxContextFieldTag.GasPrice), RLC) @@ -416,7 +424,7 @@ def opcode_lookup_at(self, index: FQ, is_code: bool) -> FQ: "The opcode source when is_root and is_create (root 
creation call) is not determined yet" ) else: - return self.bytecode_lookup(self.curr.code_source, index, is_code).expr() + return self.bytecode_lookup(self.curr.code_source, index, FQ(is_code)).expr() def rw_lookup( self, diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index c524fed71..f6dbb4faf 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -1,5 +1,6 @@ from typing import Any from .execution_state import ExecutionState +from .typing import Bytecode from ..util import FQ, RLC @@ -120,3 +121,25 @@ def __init__( self.bytes_left = FQ(bytes_left) self.src_addr_end = FQ(src_addr_end) self.is_persistent = FQ(is_persistent) + + +class CopyCodeToMemoryAuxData: + src_addr: FQ + dst_addr: FQ + bytes_left: FQ + src_addr_end: FQ + code_source: RLC + + def __init__( + self, + src_addr: int, + dst_addr: int, + bytes_left: int, + src_addr_end: int, + code_source: RLC, + ): + self.src_addr = FQ(src_addr) + self.dst_addr = FQ(dst_addr) + self.bytes_left = FQ(bytes_left) + self.src_addr_end = FQ(src_addr_end) + self.code_source = code_source diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 22cd620bb..2a60dc3b9 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -129,6 +129,15 @@ class TxContextFieldTag(IntEnum): CallData = auto() +class BytecodeFieldTag(IntEnum): + """ + Tag for BytecodeTable lookup. 
+ """ + + Length = 1 + Byte = 2 + + class RW(IntEnum): Read = 0 Write = 1 @@ -293,9 +302,10 @@ class TxTableRow(TableRow): @dataclass(frozen=True) class BytecodeTableRow(TableRow): bytecode_hash: Expression + field_tag: Expression index: Expression - byte: Expression is_code: Expression + value: Expression @dataclass(frozen=True) @@ -373,10 +383,15 @@ def tx_lookup( return _lookup(TxTableRow, self.tx_table, query) def bytecode_lookup( - self, bytecode_hash: Expression, index: Expression, is_code: Expression + self, + bytecode_hash: Expression, + field_tag: Expression, + index: Expression, + is_code: Expression = None, ) -> BytecodeTableRow: query = { "bytecode_hash": bytecode_hash, + "field_tag": field_tag, "index": index, "is_code": is_code, } diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index eda5e814c..274a779ba 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -21,6 +21,7 @@ AccountFieldTag, BlockContextFieldTag, BlockTableRow, + BytecodeFieldTag, BytecodeTableRow, CallContextFieldTag, RWTableRow, @@ -256,18 +257,25 @@ def __iter__(self): return self def __next__(self): - if self.idx == len(self.code): + # return the length of the bytecode in the first row + if self.idx == 0: + self.idx += 1 + return BytecodeTableRow( + self.hash, FQ(BytecodeFieldTag.Length), FQ(0), FQ(0), FQ(len(self.code)) + ) + + if self.idx > len(self.code): raise StopIteration - idx = self.idx + # the other rows represent each byte in the bytecode + idx = self.idx - 1 byte = self.code[idx] - is_code = self.push_data_left == 0 self.push_data_left = get_push_size(byte) if is_code else self.push_data_left - 1 - self.idx += 1 - - return BytecodeTableRow(self.hash, FQ(idx), FQ(byte), FQ(is_code)) + return BytecodeTableRow( + self.hash, FQ(BytecodeFieldTag.Byte), FQ(idx), FQ(is_code), FQ(byte) + ) return BytecodeIterator(RLC(self.hash(), randomness).expr(), self.code) diff --git a/src/zkevm_specs/util/param.py 
b/src/zkevm_specs/util/param.py index ad7ecbb05..067bd3fd7 100644 --- a/src/zkevm_specs/util/param.py +++ b/src/zkevm_specs/util/param.py @@ -71,6 +71,12 @@ # Coefficient of linear part of memory expansion gas cost MEMORY_EXPANSION_LINEAR_COEFF = 3 +# Maximum number of bytes copied during one single iteration of CopyToMemory, i.e. the internal state used by the +# CALLDATACOPY gadget +MAX_N_BYTES_COPY_TO_MEMORY = 32 +# Maximum number of bytes copied during one single iteration of CopyCodeToMemory, i.e. the internal state used by +# the CODECOPY gadget +MAX_N_BYTES_COPY_CODE_TO_MEMORY = 32 COLD_SLOAD_COST = 2100 WARM_STORAGE_READ_COST = 100 diff --git a/tests/evm/test_calldatacopy.py b/tests/evm/test_calldatacopy.py index 471143ebf..f9b1e31be 100644 --- a/tests/evm/test_calldatacopy.py +++ b/tests/evm/test_calldatacopy.py @@ -17,11 +17,11 @@ Bytecode, RWDictionary, ) -from zkevm_specs.evm.execution.memory_copy import MAX_COPY_BYTES from zkevm_specs.util import ( rand_fq, rand_bytes, GAS_COST_COPY, + MAX_N_BYTES_COPY_TO_MEMORY, MEMORY_EXPANSION_QUAD_DENOMINATOR, MEMORY_EXPANSION_LINEAR_COEFF, ) @@ -63,7 +63,7 @@ def make_copy_step( memory_size: int, gas_left: int, code_source: RLC, -) -> Tuple[StepState, Sequence[RW]]: +) -> StepState: aux_data = CopyToMemoryAuxData( src_addr=src_addr, dst_addr=dst_addr, @@ -85,7 +85,7 @@ def make_copy_step( aux_data=aux_data, ) - num_bytes = min(MAX_COPY_BYTES, bytes_left) + num_bytes = min(MAX_N_BYTES_COPY_TO_MEMORY, bytes_left) for i in range(num_bytes): byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 if not from_tx and src_addr + i < src_addr_end: @@ -129,9 +129,9 @@ def make_copy_steps( code_source, ) steps.append(new_step) - src_addr += MAX_COPY_BYTES - dst_addr += MAX_COPY_BYTES - bytes_left -= MAX_COPY_BYTES + src_addr += MAX_N_BYTES_COPY_TO_MEMORY + dst_addr += MAX_N_BYTES_COPY_TO_MEMORY + bytes_left -= MAX_N_BYTES_COPY_TO_MEMORY return steps diff --git a/tests/evm/test_codecopy.py 
b/tests/evm/test_codecopy.py new file mode 100644 index 000000000..9492811d0 --- /dev/null +++ b/tests/evm/test_codecopy.py @@ -0,0 +1,323 @@ +from itertools import chain +import pytest +from typing import Mapping, Sequence, Tuple + +from zkevm_specs.evm import ( + AccountFieldTag, + Bytecode, + CallContextFieldTag, + CopyCodeToMemoryAuxData, + ExecutionState, + Opcode, + RW, + RWDictionary, + RWTableTag, + StepState, + Tables, + verify_steps, +) +from zkevm_specs.util import ( + GAS_COST_COPY, + FQ, + MAX_N_BYTES_COPY_CODE_TO_MEMORY, + MEMORY_EXPANSION_LINEAR_COEFF, + MEMORY_EXPANSION_QUAD_DENOMINATOR, + RLC, + U64, + rand_address, + rand_fq, +) + + +CALL_ID = 1 +TESTING_DATA = ( + # single step + (0x00, 0x00, 54), + # multi step + (0x00, 0x40, 123), + # out of bounds + (0x10, 0x20, 200), +) + + +def to_word_size(addr: int) -> int: + return (addr + 31) // 32 + + +def memory_gas_cost(memory_word_size: int) -> int: + quad_cost = memory_word_size * memory_word_size // MEMORY_EXPANSION_QUAD_DENOMINATOR + linear_cost = memory_word_size * MEMORY_EXPANSION_LINEAR_COEFF + return quad_cost + linear_cost + + +def memory_copier_gas_cost( + curr_memory_word_size: int, next_memory_word_size: int, length: int +) -> int: + curr_memory_cost = memory_gas_cost(curr_memory_word_size) + next_memory_cost = memory_gas_cost(next_memory_word_size) + return to_word_size(length) * GAS_COST_COPY + next_memory_cost - curr_memory_cost + + +def make_copy_code_step( + code: Bytecode, + code_source: RLC, + buffer_map: Mapping[int, int], + src_addr: int, + dst_addr: int, + src_addr_end: int, + bytes_left: int, + rw_dictionary: RWDictionary, + program_counter: int, + stack_pointer: int, + memory_size: int, + randomness: FQ, +) -> StepState: + aux_data = CopyCodeToMemoryAuxData( + src_addr=src_addr, + dst_addr=dst_addr, + src_addr_end=src_addr_end, + bytes_left=bytes_left, + code_source=RLC(code.hash(), randomness), + ) + step = StepState( + execution_state=ExecutionState.CopyCodeToMemory, + 
rw_counter=rw_dictionary.rw_counter, + call_id=CALL_ID, + is_root=True, + program_counter=program_counter, + stack_pointer=stack_pointer, + gas_left=0, + memory_size=memory_size, + code_source=code_source, + aux_data=aux_data, + ) + + num_bytes = min(MAX_N_BYTES_COPY_CODE_TO_MEMORY, bytes_left) + for i in range(num_bytes): + byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 + rw_dictionary.memory_write(CALL_ID, dst_addr + i, byte) + return step + + +def make_copy_code_steps( + code: Bytecode, + code_source: RLC, + src_addr: int, + dst_addr: int, + length: int, + rw_dictionary: RWDictionary, + program_counter: int, + stack_pointer: int, + memory_size: int, + randomness: FQ, +) -> Sequence[StepState]: + buffer_map = dict(zip(range(src_addr, len(code.code)), code.code)) + steps = [] + bytes_left = length + while bytes_left > 0: + new_step = make_copy_code_step( + code, + code_source, + buffer_map, + src_addr, + dst_addr, + len(code.code), + bytes_left, + rw_dictionary, + program_counter, + stack_pointer, + memory_size, + randomness, + ) + steps.append(new_step) + src_addr += MAX_N_BYTES_COPY_CODE_TO_MEMORY + dst_addr += MAX_N_BYTES_COPY_CODE_TO_MEMORY + bytes_left -= MAX_N_BYTES_COPY_CODE_TO_MEMORY + return steps + + +@pytest.mark.parametrize("src_addr, dst_addr, length", TESTING_DATA) +def test_codecopy(src_addr: U64, dst_addr: U64, length: U64): + randomness = rand_fq() + + length_rlc = RLC(length, randomness) + src_addr_rlc = RLC(src_addr, randomness) + dst_addr_rlc = RLC(dst_addr, randomness) + + code = Bytecode().push32(length_rlc).push32(src_addr_rlc).push32(dst_addr_rlc).codecopy().stop() + + code_source = RLC(code.hash(), randomness) + next_memory_word_size = to_word_size(dst_addr + length) + + gas_cost_push32 = Opcode.PUSH32.constant_gas_cost() + gas_cost_codecopy = Opcode.CODECOPY.constant_gas_cost() + memory_copier_gas_cost( + 0, next_memory_word_size, length + ) + total_gas_cost = gas_cost_codecopy + (3 * gas_cost_push32) + + 
rw_dictionary = ( + RWDictionary(1) + .stack_write(CALL_ID, 1023, length_rlc) + .stack_write(CALL_ID, 1022, src_addr_rlc) + .stack_write(CALL_ID, 1021, dst_addr_rlc) + .stack_read(CALL_ID, 1021, dst_addr_rlc) + .stack_read(CALL_ID, 1022, src_addr_rlc) + .stack_read(CALL_ID, 1023, length_rlc) + ) + # rw counter before memory writes + rw_counter_interim = rw_dictionary.rw_counter + + steps = [ + StepState( + execution_state=ExecutionState.PUSH, + rw_counter=1, + call_id=CALL_ID, + is_root=True, + code_source=code_source, + program_counter=0, + stack_pointer=1024, + gas_left=total_gas_cost, + ), + StepState( + execution_state=ExecutionState.PUSH, + rw_counter=2, + call_id=CALL_ID, + is_root=True, + code_source=code_source, + program_counter=33, + stack_pointer=1023, + gas_left=total_gas_cost - gas_cost_push32, + ), + StepState( + execution_state=ExecutionState.PUSH, + rw_counter=3, + call_id=CALL_ID, + is_root=True, + code_source=code_source, + program_counter=66, + stack_pointer=1022, + gas_left=total_gas_cost - 2 * gas_cost_push32, + ), + StepState( + execution_state=ExecutionState.CODECOPY, + rw_counter=4, + call_id=CALL_ID, + is_root=True, + code_source=code_source, + program_counter=99, + stack_pointer=1021, + gas_left=gas_cost_codecopy, + ), + ] + + steps_internal = make_copy_code_steps( + code, + code_source, + src_addr, + dst_addr, + length, + rw_dictionary=rw_dictionary, + program_counter=100, + stack_pointer=1024, + memory_size=next_memory_word_size, + randomness=randomness, + ) + steps.extend(steps_internal) + + # rw counter post memory writes + rw_counter_final = rw_dictionary.rw_counter + assert rw_counter_final - rw_counter_interim == length + + steps.append( + StepState( + execution_state=ExecutionState.STOP, + rw_counter=rw_dictionary.rw_counter, + call_id=CALL_ID, + is_root=True, + code_source=code_source, + program_counter=100, + stack_pointer=1024, + memory_size=next_memory_word_size, + gas_left=0, + ) + ) + + tables = Tables( + block_table=set(), + 
tx_table=set(), + bytecode_table=set(code.table_assignments(randomness)), + rw_table=set(rw_dictionary.rws), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=steps, + ) + + +@pytest.mark.parametrize("src_addr, dst_addr, length", TESTING_DATA) +def test_copy_code_to_memory(src_addr: U64, dst_addr: U64, length: U64): + randomness = rand_fq() + + code = ( + Bytecode() + .push32(0x123) + .pop() + .push32(0x213) + .pop() + .push32(0x321) + .pop() + .push32(0x12349AB) + .pop() + .push32(0x1928835) + .pop() + ) + + dummy_code = Bytecode().stop() + code_source = RLC(dummy_code.hash(), randomness) + + rw_dictionary = RWDictionary(1) + + next_memory_word_size = to_word_size(dst_addr + length) + steps = make_copy_code_steps( + code, + code_source, + src_addr, + dst_addr, + length, + rw_dictionary=rw_dictionary, + program_counter=0, + memory_size=next_memory_word_size, + stack_pointer=1024, + randomness=randomness, + ) + steps.append( + StepState( + execution_state=ExecutionState.STOP, + rw_counter=rw_dictionary.rw_counter, + call_id=CALL_ID, + is_root=True, + is_create=False, + code_source=code_source, + program_counter=0, + stack_pointer=1024, + memory_size=next_memory_word_size, + gas_left=0, + ) + ) + + tables = Tables( + block_table=set(), + tx_table=set(), + bytecode_table=set(code.table_assignments(randomness)).union( + dummy_code.table_assignments(randomness) + ), + rw_table=set(rw_dictionary.rws), + ) + + verify_steps( + randomness=randomness, + tables=tables, + steps=steps, + ) diff --git a/tests/test_bytecode_circuit.py b/tests/test_bytecode_circuit.py index 270c56e24..91eea29fc 100644 --- a/tests/test_bytecode_circuit.py +++ b/tests/test_bytecode_circuit.py @@ -2,7 +2,7 @@ from copy import deepcopy from zkevm_specs.bytecode import * -from zkevm_specs.evm import Opcode, Bytecode, BytecodeTableRow, is_push +from zkevm_specs.evm import Opcode, Bytecode, BytecodeFieldTag, BytecodeTableRow, is_push from zkevm_specs.util import RLC, rand_fq # 
Unroll the bytecode @@ -18,7 +18,8 @@ def verify(k, bytecodes, randomness, success): try: for (idx, row) in enumerate(rows): prev_row = rows[(idx - 1) % len(rows)] - check_bytecode_row(row, prev_row, push_table, keccak_table, randomness) + next_row = rows[(idx + 1) % len(rows)] + check_bytecode_row(row, prev_row, next_row, push_table, keccak_table, randomness) ok = True except AssertionError as e: if success: @@ -38,19 +39,21 @@ def test_bytecode_unrolling(): for byte in range(256): if not is_push(byte): bytecode.append(byte) - rows.append((0, len(rows), byte, True)) + rows.append((0, BytecodeFieldTag.Byte, len(rows), True, byte)) # Now add the different push ops for n in range(1, 33): data_byte = int(Opcode.PUSH32) bytecode.append(Opcode.PUSH1 + n - 1) bytecode.extend([data_byte] * n) - rows.append((0, len(rows), Opcode.PUSH1 + n - 1, True)) + rows.append((0, BytecodeFieldTag.Byte, len(rows), True, Opcode.PUSH1 + n - 1)) for _ in range(n): - rows.append((0, len(rows), data_byte, False)) + rows.append((0, BytecodeFieldTag.Byte, len(rows), False, data_byte)) # Set the hash of the complete bytecode in the rows hash = RLC(bytes(reversed(keccak256(bytes(bytecode)))), randomness) for i in range(len(rows)): - rows[i] = BytecodeTableRow(hash.expr(), rows[i][1], rows[i][2], rows[i][3]) + rows[i] = BytecodeTableRow(hash.expr(), rows[i][1], rows[i][2], rows[i][3], rows[i][4]) + # Prepend the length of bytecode to rows + rows.insert(0, BytecodeTableRow(hash.expr(), BytecodeFieldTag.Length, 0, 0, len(bytecode))) # Unroll the bytecode unrolled = unroll(bytes(bytecode), randomness) # Check if the bytecode was unrolled correctly @@ -65,7 +68,7 @@ def test_bytecode_empty(): def test_bytecode_full(): - bytecodes = [unroll(bytes([7] * 2**k), randomness)] + bytecodes = [unroll(bytes([7] * (2**k - 1)), randomness)] verify(k, bytecodes, randomness, True) @@ -89,22 +92,34 @@ def test_bytecode_invalid_hash_data(): unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) verify(k, 
[unrolled], randomness, True) - # Change the hash on the first position + # Change the hash on the first row, i.e. row denoting tag Length invalid = deepcopy(unrolled) row = unrolled.rows[0] - invalid.rows[0] = BytecodeTableRow(row.bytecode_hash + 1, row.index, row.byte, row.is_code) + invalid.rows[0] = BytecodeTableRow( + row.bytecode_hash + 1, row.field_tag, row.index, row.is_code, row.value + ) + verify(k, [invalid], randomness, False) + + # Change the hash on the second row, i.e. first row with tag Byte + invalid = deepcopy(unrolled) + row = unrolled.rows[1] + invalid.rows[1] = BytecodeTableRow( + row.bytecode_hash + 1, row.field_tag, row.index, row.is_code, row.value + ) verify(k, [invalid], randomness, False) # Change the hash on another position invalid = deepcopy(unrolled) row = unrolled.rows[4] - invalid.rows[0] = BytecodeTableRow(row.bytecode_hash + 1, row.index, row.byte, row.is_code) + invalid.rows[4] = BytecodeTableRow( + row.bytecode_hash + 1, row.field_tag, row.index, row.is_code, row.value + ) verify(k, [invalid], randomness, False) # Change all the hashes so it doesn't match the keccak lookup hash invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): - invalid.rows[idx] = BytecodeTableRow(1, row.index, row.byte, row.is_code) + invalid.rows[idx] = BytecodeTableRow(1, row.field_tag, row.index, row.is_code, row.value) verify(k, [invalid], randomness, False) @@ -116,14 +131,14 @@ def test_bytecode_invalid_index(): invalid = deepcopy(unrolled) for idx, row in enumerate(unrolled.rows): invalid.rows[idx] = BytecodeTableRow( - row.bytecode_hash + 1, row.index, row.byte, row.is_code + row.bytecode_hash + 1, row.field_tag, row.index, row.is_code, row.value ) verify(k, [invalid], randomness, False) # Don't increment an index once invalid = deepcopy(unrolled) invalid.rows[-1] = BytecodeTableRow( - invalid.rows[-1].bytecode_hash - 1, row.index, row.byte, row.is_code + invalid.rows[-1].bytecode_hash - 1, row.field_tag, row.index, row.is_code, 
row.value ) verify(k, [invalid], randomness, False) @@ -132,22 +147,24 @@ def test_bytecode_invalid_byte_data(): unrolled = unroll(bytes([8, 2, 3, 8, 9, 7, 128]), randomness) verify(k, [unrolled], randomness, True) - # Change the first byte + # Change the first byte in the bytecode invalid = deepcopy(unrolled) - row = unrolled.rows[0] - invalid.rows[0] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 9) + row = unrolled.rows[1] + invalid.rows[1] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, row.is_code, 9) verify(k, [invalid], randomness, False) # Change a byte on another position invalid = deepcopy(unrolled) row = unrolled.rows[5] - invalid.rows[5] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 6) + invalid.rows[5] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, row.is_code, 6) verify(k, [invalid], randomness, False) # Set a byte value out of range invalid = deepcopy(unrolled) row = unrolled.rows[3] - invalid.rows[3] = BytecodeTableRow(row.bytecode_hash, row.index, row.byte, 256) + invalid.rows[3] = BytecodeTableRow( + row.bytecode_hash, row.field_tag, row.index, row.is_code, 256 + ) verify(k, [invalid], randomness, False) @@ -168,20 +185,21 @@ def test_bytecode_invalid_is_code(): ) verify(k, [unrolled], randomness, True) + # The first row, i.e. index == 0 is taken up by the tag Length. 
# Mark the 3rd byte as code (is push data from the first PUSH1) invalid = deepcopy(unrolled) - row = unrolled.rows[2] - invalid.rows[2] = BytecodeTableRow(row.bytecode_hash, row.index, 1, row.is_code) + row = unrolled.rows[3] + invalid.rows[3] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, 1, row.value) verify(k, [invalid], randomness, False) - # Mark the 4rd byte as data (is code) + # Mark the 4th byte as data (is code) invalid = deepcopy(unrolled) - row = unrolled.rows[3] - invalid.rows[3] = BytecodeTableRow(row.bytecode_hash, row.index, 0, row.is_code) + row = unrolled.rows[4] + invalid.rows[4] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, 0, row.value) verify(k, [invalid], randomness, False) # Mark the 7th byte as code (is data for the PUSH7) invalid = deepcopy(unrolled) - row = unrolled.rows[6] - invalid.rows[6] = BytecodeTableRow(row.bytecode_hash, row.index, 1, row.is_code) + row = unrolled.rows[7] + invalid.rows[7] = BytecodeTableRow(row.bytecode_hash, row.field_tag, row.index, 1, row.value) verify(k, [invalid], randomness, False)