diff --git a/specs/copy-proof.md b/specs/copy-proof.md new file mode 100644 index 000000000..03d40303f --- /dev/null +++ b/specs/copy-proof.md @@ -0,0 +1,57 @@ +# Copy Proof + +The copy proof checks the values in the copy table and applies the lookup arguments to the corresponding tables to check if the value read from and write to data source is correct. +It also checks the padding behavior that the value read from an out-of-boundary address is 0. + +## Circuit Layout + +First, copy circuit contains 13 columns from the [copy table](./tables.md#copytable) with the same witness assignment. +Every two rows in the copy circuit represent a copy step where the first row is a read operation and the second is a write operation. +A copy event consists of multiple copy steps, which the first row in the copy event has `is_first` assigned to 1 and the last row has `is_last` assigned to 1. + +In addition to the columns in the copy table, copy circuit adds a few auxiliary columns to help check the constraints. + +- `is_memory`: indicates if `Type` is `Memory` using `IsZero` gadget. +- `is_bytecode`: indicates if `Type` is `Bytecode` using `IsZero` gadget. +- `is_tx_calldata`: indicates if `Type` is `TxCalldata` using `IsZero` gadget. +- `is_tx_log`: indicates if `Type` is `TxLog` using `IsZero` gadget. + +## Circuit Constraints + +The constraints are divided into three groups. + +First, the circuit adds common constraints that applied to every rows in the circuit: + +- Boolean check for `is_first`, and `is_last` +- Check `is_first == 0` when `q_step == 0` +- Check `is_last == 0` when `q_step == 1` +- Construct the IsZero gadget and constrain `is_memory`, `is_bytecode`, `is_tx_calldata`, and `is_tx_log`. +- The transition constraints from a copy step to the next step (with 2-row rotation), applied to all rows except the last two rows (the last step) in a copy event: + - `ID`, `Type`, `AddressEnd` should be same between two steps. + - `Address` increase by 1 in the next copy step. +- The transition constraints for `RwCounter` and `RwcIncreaseLeft` column + - define `rw_diff` to be 1 if the `Type` is `Memory` or `TxLog` and `Padding` is 0 in the current row; otherwise 0. + - when it's not the last row in a copy event (`is_last == 0`), `RwCounter` increases by `rw_diff` and `RwcIncreaseLeft` decrases by `rw_diff`. + - when it's the last row in a copy event (`is_last == 1`), `RwcIncreaseLeft` is equal to `rw_diff`. + +Second, the circuit adds the constraints for every copy step in the circuit, when `q_step` is 1. + +- Look up the copy type pair `(Type, Type[1])` in a fixed table to make sure it's a valid copy step. +- Constrain the transition for `BytesLeft` + - when it's not the last step (`is_last[1] != 1`), decrease by 1 in the next step + - otherwise, equals to 1. +- Constrain the write value equals to read value: `Value[1] == Value` +- Constrain `Value == 0` when `Padding == 1`. +- Construct the LT gadget to compare `Address` and `AddressEnd` in the read operation. If `Address >= AddressEnd`, constrain `Padding == 1` +- Constrain `Padding[1] == 0` as the write operation is never padded. + +Third, the circuit adds the lookup arguments to the corresponding tables. + +- When `Type` is `Memory` or `is_memory == 1` and `Padding == 0`, look up the `Value` to `rw_table` with `Memory` tag. +- When `Type` is `TxCalldata` or `is_tx_calldata == 1` and `Padding == 0`, look up the `Value` to `tx_table`. +- When `Type` is `Bytecode` or `is_bytecode == 1` and `Padding == 0`, look up the `Value` and `IsCode` to `bytecode_table` +- When `Type` is `TxLog` or `is_tx_log == 1`, look up the `Value` to `rw_table` with `TxLog` tag. + +## Code + +Please refer to `src/zkevm-specs/copy_circuit.py` diff --git a/specs/opcode/CopyCodeToMemory.md b/specs/opcode/CopyCodeToMemory.md deleted file mode 100644 index c6b67abde..000000000 --- a/specs/opcode/CopyCodeToMemory.md +++ /dev/null @@ -1,31 +0,0 @@ -# CopyCodeToMemory - -## Circuit Behaviour - -`CopyCodeToMemory` is an internal execution state and doesn't correspond to an EVM opcode. It verifies that data from bytecode table has been written to memory. This gadget can in one iteration only copy `MAX_COPY_BYTES` number of bytes, hence for lengths longer than the bound the gadget loops itself until there are no more bytes to be copied. - -The `CopyCodeToMemory` circuit uses the `BufferReaderGadget` to check if the access is out of bounds and needs 0 padding. - -The `CopyCodeToMemory` circuit looks up the bytes read from buffer against both the bytecode table and the RW table (memory-write). An additional constraint checks whether or not the copying is finished, and if not, it constrains the next execution state to continue being `CopyCodeToMemory` while also adding constraints to the next step's auxiliary data. - -## Constraints - -We define `n_bytes_read` as the number of bytes read from the bytecode table. `n_bytes_read <= MAX_COPY_BYTES`. - -We define `n_bytes_written` as the number of bytes written to the memory. `n_bytes_written <= MAX_COPY_BYTES`. - -`n_bytes_read` differs from `n_bytes_written` in out-of-bound cases where nothing is read from the bytecode table but a `0` is written to memory. - -1. State Transition: - - rw_counter: `n_bytes_written` -2. Lookups: - - `n_bytes_read` lookups from bytecode table - - `n_bytes_written` lookups from RW table (memory-write) - -## Exceptions - -No exceptions for `CopyCodeToMemory` since it is an internal state. - -## Code - -Please refer to `src/zkevm_specs/evm/execution/copy_code_to_memory.py`. diff --git a/specs/opcode/CopyToLog.md b/specs/opcode/CopyToLog.md deleted file mode 100644 index f4b8a85e2..000000000 --- a/specs/opcode/CopyToLog.md +++ /dev/null @@ -1,42 +0,0 @@ -# CopyToLog - -## Circuit behaviour - -`CopyToLog` is an internal execution state and doesn't correspond to an EVM opcode. It copies -data from memory RW `Txlog` entries. This gadget needs to loop itself if it hasn't finished -the copy. - -The `CopyToLog` circuit uses gadget `BufferReaderGadget` like `CopyToMemory`. - -In the `CopyToLog` circuit, it needs to lookup the bytes that is read from the buffer according -to `BufferReaderGadget` in the memory. It needs to check if the copy is finished. If not, it constrains the -next execution state to still be `CopyToLog` and adds extra constraints on the states in the next -`CopyToLog`. - -## Constraints - -Define two auxiliary variables: - -- `nbytes_written`: the number of bytes written to `TxLog`. It could be either `MAX_COPY_BYTES` - or number of bytes left to copy. -- `nbytes_read`: the number of bytes read from memory. It's no more than - `nbytes_written`. `nbytes_read` is smaller than `nbytes_written` when there's out-of-bound - access to the src buffer. - -1. State transition: - - rw_counter - - from memory: + nbytes_read - - to txlog: + nbytes_written -2. Lookups: - - from memory: - - `nbytes_read` lookups from rw table (memory read) - - to txlog when is_persistent is true: - - `nbytes_written` lookups from rw table (TxLog write) - -## Exceptions - -No exceptions for `CopyToLog` as it's an internal state. - -## Code - -Please refer to `src/zkevm_specs/evm/execution/copy_to_log.py`. diff --git a/specs/opcode/CopyToMemory.md b/specs/opcode/CopyToMemory.md deleted file mode 100644 index cb0ecf975..000000000 --- a/specs/opcode/CopyToMemory.md +++ /dev/null @@ -1,52 +0,0 @@ -# CopyToMemory opcode - -## Circuit behaviour - -`CopyToMemory` is an internal execution state and doesn't correspond to an EVM opcode. It can copy -data from either tx or memory to memory. This gadget needs to loop itself if it hasn't finished -the copy. - -The `CopyToMemory` circuit uses another gadget `BufferReaderGadget` to check if the access is out of -bound and needs 0 padding. The `BufferReaderGadget` couples every bytes accessed with `selectors` -that indicate whether a byte has value, and `bound_dist` that is defined as -`max(0, addr_end - addr)`. When `bound_dist[i] == 0`, it indicates that the buffer access at index -`i` is out of bound and therefore needs to pad 0. For `bound_dist[0]`, we need to constrain the -value using `min` gadget. But we only need to limit on the difference between two consecutive -`bound_dist` values for the rest of `bound_dist` array as we know its value decreases by 1 each -time until 0. - -In the `CopyToMemory` circuit, it needs to lookup the bytes that is read from the buffer according -to `BufferReaderGadget` in the tx context table or memory table depending whether the src buffer -is from tx or memory. Last, it needs to check if the copy is finished. If not, it constrains the -next execution state to still be `CopyToMemory` and add extra constraints on the states in the next -`CopyToMemory`. - -## Constraints - -Define two auxiliary variables: - -- `nbytes_written`: the number of bytes written to the memory. It could be either `MAX_COPY_BYTES` - or number of bytes left to copy. -- `nbytes_read`: the number of bytes read from tx table or memory. It's no more than - `nbytes_written`. `nbytes_read` is smaller than `nbytes_written` when there's out-of-bound - access to the src buffer. - -1. State transition: - - rw_counter - - from tx: + nbytes_written - - from memory: + nbytes_written + nbytes_read -2. Lookups: nbytes_read + nbytes_written - - from tx: - - `nbytes_read` lookups from tx context table - - `nbytes_written` lookups from rw table (memory write to call_id) - - from memory: - - `nbytes_read` lookups from rw table (memory read from caller_id) - - `nbytes_written` lookups from rw table (memory write to call_id) - -## Exceptions - -No exceptions for `CopyToMemory` as it's an internal state. - -## Code - -Please refer to `src/zkevm_specs/evm/execution/memory_copy.py`. diff --git a/specs/tables.md b/specs/tables.md index d73cc1a80..b245b978e 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -271,3 +271,52 @@ Columns expressions in circuit: - Key: `key_rlc_mult` - ValuePrev: `is_nonce_mod * sel1 + is_balance_mod * sel2 + is_codehash_mod * sel2 + is_storage_mod * mult_diff` - ValueCur: `is_nonce_mod * s_mod_node_hash_rlc + is_balance_mod * c_mod_node_hash_rlc + is_codehash_mod * c_mod_node_hash_rlc + is_storage_mod * acc_c` + +## `copy_table` + +Proved by the copy circuit. + +The copy table consists of 13 columns, described as follows: + +- **q_step**: a fixed column for boolean value to indicate a copy step, always alternating between 1 and 0, where 1 indicates a read op and 0 indicates a write op. +- **is_first**: a boolean value to indicate the first row in a copy event. +- **is_last**: a boolean value to indicate the last row in a copy event. +- **ID**: could be `$txID`, `$callID`, `$codeHash` (RLC encoded). +- **Type**: indicates the type of data source, including `Memory`, `Bytecode`, `TxCalldata`, `TxLog`. +- **Address**: indicates the address in the source data, could be memory address, byte index in the bytecode, tx call data, and tx log data. When the data type is `TxLog`, the address is the combination of byte index, `TxLogFieldTag.Data` tag, and `LogID`. +- **AddressEnd**: indicates the address boundary of the source data. Any data read from address greater than or equal to `AddressEnd` should be 0. Note `AddressEnd` is only valid for read operations or `q_step` is 1. +- **BytesLeft**: indicates the number of bytes left to be copied. +- **Value**: indicates the value read or write from source or to the destination. +- **Pad**: indicates if the value read from the source is padded. Only valid for read operations or `q_step` is 1. +- **IsCode**: a boolean value to indicate if the `Value` is an executable opcode or the data portion of `PUSH*` operations. Only valid when `Type` is `Bytecode`. +- **RwCounter**: indicates the current RW counter at this row. This value will be used in the lookup to the `rw_table` when `Type` is `Memory` or `TxLog`. +- **RwcIncreaseLeft**: indicates how much the RW counter will increase in a copy event. + + +Unlike other lookup tables, the copy table is a virtual table. The lookup entry is not a single row in the table, and not every row corresponds to a lookup entry. +Instead, a lookup entry is constructed from the first two rows in each copy event as +`(is_first, ID, Type, ID[1], Type[1], Address, AddressEnd, Address[1], BytesLeft, RwCounter, RwcIncreaseLeft)`, where `is_first` is 1 and `Column[1]` indicates the next row in the corresponding column. + +The table below lists all of copy pairs supported in the copy table: +- Copy from Tx call data to memory (`CALLDATACOPY`). +- Copy from caller/callee memory to callee/caller memory (`CALLDATACOPY`, `RETURN` (not create), `RETURNDATACOPY`, `REVERT`). +- Copy from bytecode to memory (`CODECOPY`, `EXTCODECOPY`). +- Copy from memory to bytecode (`CREATE`, `CREATE2`, `RETURN` (create)) +- Copy from memory to TxLog in the `rw_table` (`LOGX`) + +| q_step | q_first | q_last | ID | Type | Address | AddressEnd | BytesLeft | Value | IsCode | Pad | RwCounter | RwcIncreaseLeft | +|--------|---------|--------|-----------|------------|----------------|----------------|------------|--------|---------|-----|-----------|-----------------| +| 1 | 0/1 | 0 | $txID | TxCalldata | $byteIndex | $cdLength | $bytesLeft | $value | - | 0/1 | - | $rwcIncLeft | +| 0 | 0 | 0/1 | $callID | Memory | $memoryAddress | - | - | $value | - | 0 | $counter | $rwcIncLeft | +| | | | | | | | | | | | | | +| 1 | 0/1 | 0 | $callID | Memory | $memoryAddress | $memoryAddress | $bytesLeft | $value | - | 0/1 | $counter | $rwcIncLeft | +| 0 | 0 | 0/1 | $callID | Memory | $memoryAddress | - | - | $value | - | 0 | $counter | $rwcIncLeft | +| | | | | | | | | | | | | | +| 1 | 0/1 | 0 | $callID | Memory | $memoryAddress | $memoryAddress | $bytesLeft | $value | $isCode | 0/1 | $counter | $rwcIncLeft | +| 0 | 0 | 0/1 | $codeHash | Bytecode | $byteIndex | - | - | $value | $isCode | 0 | - | $rwcIncLeft | +| | | | | | | | | | | | | | +| 1 | 0/1 | 0 | $codeHash | Bytecode | $byteIndex | $codeLength | $bytesLeft | $value | $isCode | 0/1 | - | $rwcIncLeft | +| 0 | 0 | 0/1 | $callID | Memory | $memoryAddress | - | - | $value | $isCode | 0 | $counter | $rwcIncLeft | +| | | | | | | | | | | | | | +| 1 | 0/1 | 0 | $callID | Memory | $memoryAddress | $memoryAddress | $bytesLeft | $value | - | 0/1 | $counter | $rwcIncLeft | +| 0 | 0 | 0/1 | $txID | TxLog | $byteIndex \|\| TxLogData \|\| $logID | - | - | $value | - | 0 | $counter | $rwcIncLeft | diff --git a/src/zkevm_specs/__init__.py b/src/zkevm_specs/__init__.py index 66e7e9ba3..80ed20e13 100644 --- a/src/zkevm_specs/__init__.py +++ b/src/zkevm_specs/__init__.py @@ -1,6 +1,9 @@ from . import bytecode +from . import copy_circuit from . import encoding from . import evm from . import opcode +from . import state +from . import tx from . import util from . import tx diff --git a/src/zkevm_specs/copy_circuit.py b/src/zkevm_specs/copy_circuit.py new file mode 100644 index 000000000..7d729689a --- /dev/null +++ b/src/zkevm_specs/copy_circuit.py @@ -0,0 +1,111 @@ +from typing import Dict, Iterator, List, NewType, Optional, Sequence, Union, Mapping, Tuple + +from .util import FQ, Expression, ConstraintSystem, cast_expr, MAX_N_BYTES, N_BYTES_MEMORY_ADDRESS +from .evm import ( + Tables, + CopyDataTypeTag, + CopyCircuitRow, + RW, + RWTableTag, + FixedTableTag, + CopyCircuit, + TxContextFieldTag, + BytecodeFieldTag, + TxLogFieldTag, +) + + +def lt(lhs: Expression, rhs: Expression, n_bytes: int) -> FQ: + assert n_bytes <= MAX_N_BYTES, "Too many bytes to composite an integer in field" + assert lhs.expr().n < 256**n_bytes, f"lhs {lhs} exceeds the range of {n_bytes} bytes" + assert rhs.expr().n < 256**n_bytes, f"rhs {rhs} exceeds the range of {n_bytes} bytes" + return FQ(lhs.expr().n < rhs.expr().n) + + +def verify_row(cs: ConstraintSystem, rows: Sequence[CopyCircuitRow]): + cs.constrain_bool(rows[0].is_first) + cs.constrain_bool(rows[0].is_last) + # is_first == 0 when q_step == 0 + cs.constrain_zero((1 - rows[0].q_step) * rows[0].is_first) + # is_last == 0 when q_step == 1 + cs.constrain_zero(rows[0].q_step * rows[0].is_last) + cs.constrain_equal(rows[0].is_memory, cs.is_zero(rows[0].tag - CopyDataTypeTag.Memory)) + cs.constrain_equal(rows[0].is_bytecode, cs.is_zero(rows[0].tag - CopyDataTypeTag.Bytecode)) + cs.constrain_equal(rows[0].is_tx_calldata, cs.is_zero(rows[0].tag - CopyDataTypeTag.TxCalldata)) + cs.constrain_equal(rows[0].is_tx_log, cs.is_zero(rows[0].tag - CopyDataTypeTag.TxLog)) + + # constrain the transition between two copy steps + is_last_two_rows = rows[0].is_last + rows[1].is_last + with cs.condition(1 - is_last_two_rows) as cs: + # not last two rows + cs.constrain_equal(rows[0].id, rows[2].id) + cs.constrain_equal(rows[0].tag, rows[2].tag) + cs.constrain_equal(rows[0].addr + 1, rows[2].addr) + cs.constrain_equal(rows[0].src_addr_end, rows[2].src_addr_end) + + # contrain the transition for `rw_counter` and `rwc_inc_left` + rw_diff = (1 - rows[0].is_pad) * (rows[0].is_memory + rows[0].is_tx_log) + with cs.condition(1 - rows[0].is_last) as cs: + # not last row + cs.constrain_equal(rows[0].rw_counter + rw_diff, rows[1].rw_counter) + cs.constrain_equal(rows[0].rwc_inc_left - rw_diff, rows[1].rwc_inc_left) + with cs.condition(rows[0].is_last) as cs: + # rwc_inc_left == rw_diff for last row in the copy slot + cs.constrain_equal(rows[0].rwc_inc_left, rw_diff) + + +def verify_step(cs: ConstraintSystem, rows: Sequence[CopyCircuitRow]): + with cs.condition(rows[0].q_step): + # bytes_left == 1 for last step + cs.constrain_zero(rows[1].is_last * (1 - rows[0].bytes_left)) + # bytes_left == bytes_left_next + 1 for non-last step + cs.constrain_zero((1 - rows[1].is_last) * (rows[0].bytes_left - rows[2].bytes_left - 1)) + # write value == read value + cs.constrain_equal(rows[0].value, rows[1].value) + # value == 0 when is_pad == 1 for read + cs.constrain_zero(rows[0].is_pad * rows[0].value) + # is_pad == 1 - (src_addr < src_addr_end) for read row + cs.constrain_equal( + 1 - lt(rows[0].addr, rows[0].src_addr_end, N_BYTES_MEMORY_ADDRESS), rows[0].is_pad + ) + # is_pad == 0 for write row + cs.constrain_zero(rows[1].is_pad) + + +def verify_copy_table(copy_circuit: CopyCircuit, tables: Tables): + cs = ConstraintSystem() + copy_table = copy_circuit.table() + n = len(copy_table) + for i, row in enumerate(copy_table): + rows = [ + row, + copy_table[(i + 1) % n], + copy_table[(i + 2) % n], + ] + # constrain on each row and step + verify_row(cs, rows) + verify_step(cs, rows) + + # lookup into tables + if row.is_memory == 1 and row.is_pad == 0: + val = tables.rw_lookup( + row.rw_counter, 1 - row.q_step, FQ(RWTableTag.Memory), row.id, row.addr + ).value + cs.constrain_equal(cast_expr(val, FQ), row.value) + if row.is_bytecode == 1 and row.is_pad == 0: + val = tables.bytecode_lookup( + row.id, FQ(BytecodeFieldTag.Byte), row.addr, row.is_code + ).value + cs.constrain_equal(cast_expr(val, FQ), row.value) + if row.is_tx_calldata == 1 and row.is_pad == 0: + val = tables.tx_lookup(row.id, FQ(TxContextFieldTag.CallData), row.addr).value + cs.constrain_equal(val, row.value) + if row.is_tx_log == 1: + val = tables.rw_lookup( + row.rw_counter, + FQ(RW.Write), + FQ(RWTableTag.TxLog), + row.id, # tx_id + row.addr, + ).value + cs.constrain_equal(cast_expr(val, FQ), row.value) diff --git a/src/zkevm_specs/evm/execution/__init__.py b/src/zkevm_specs/evm/execution/__init__.py index 6e236cbfe..4eeeff231 100644 --- a/src/zkevm_specs/evm/execution/__init__.py +++ b/src/zkevm_specs/evm/execution/__init__.py @@ -3,11 +3,8 @@ from ..execution_state import ExecutionState from .begin_tx import * -from .copy_code_to_memory import * from .end_tx import * from .end_block import * -from .memory_copy import * -from .copy_to_log import * # Opcode's successful cases from .add_sub import * @@ -46,8 +43,6 @@ ExecutionState.BeginTx: begin_tx, ExecutionState.EndTx: end_tx, ExecutionState.EndBlock: end_block, - ExecutionState.CopyCodeToMemory: copy_code_to_memory, - ExecutionState.CopyToMemory: copy_to_memory, ExecutionState.ADD: add_sub, ExecutionState.ADDMOD: addmod, ExecutionState.MULMOD: mulmod, @@ -72,7 +67,6 @@ ExecutionState.SELFBALANCE: selfbalance, ExecutionState.GASPRICE: gasprice, ExecutionState.EXTCODEHASH: extcodehash, - ExecutionState.CopyToLog: copy_to_log, ExecutionState.LOG: log, ExecutionState.CALL: call, ExecutionState.ISZERO: iszero, diff --git a/src/zkevm_specs/evm/execution/calldatacopy.py b/src/zkevm_specs/evm/execution/calldatacopy.py index 186271bc9..b225e2791 100644 --- a/src/zkevm_specs/evm/execution/calldatacopy.py +++ b/src/zkevm_specs/evm/execution/calldatacopy.py @@ -1,7 +1,7 @@ from ...util import N_BYTES_MEMORY_ADDRESS, FQ, Expression from ..execution_state import ExecutionState from ..instruction import Instruction, Transition -from ..table import RW, CallContextFieldTag, TxContextFieldTag +from ..table import RW, CallContextFieldTag, TxContextFieldTag, CopyDataTypeTag def calldatacopy(instruction: Instruction): @@ -35,22 +35,27 @@ def calldatacopy(instruction: Instruction): ) gas_cost = instruction.memory_copier_gas_cost(length, memory_expansion_gas_cost) - # When length != 0, constrain the state in the next execution state CopyToMemory - if instruction.is_zero(length) == FQ(0): - instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) - next_aux = instruction.next.aux_data - instruction.constrain_equal(next_aux.src_addr, data_offset + call_data_offset) - instruction.constrain_equal(next_aux.dst_addr, memory_offset) - instruction.constrain_equal( - next_aux.src_addr_end, call_data_length.expr() + call_data_offset + src_type = instruction.select( + FQ(instruction.curr.is_root), FQ(CopyDataTypeTag.TxCalldata), FQ(CopyDataTypeTag.Memory) + ) + if instruction.is_zero(length) == 0: + copy_rwc_inc = instruction.copy_lookup( + src_id, + CopyDataTypeTag(src_type.n), + instruction.curr.call_id, + CopyDataTypeTag.Memory, + call_data_offset.expr() + data_offset.expr(), + call_data_offset.expr() + call_data_length.expr(), + memory_offset, + length, + instruction.curr.rw_counter + instruction.rw_counter_offset, ) - instruction.constrain_equal(next_aux.from_tx, FQ(instruction.curr.is_root)) - instruction.constrain_equal(next_aux.src_id, src_id) - instruction.constrain_equal(next_aux.bytes_left, length) + else: + copy_rwc_inc = FQ(0) instruction.step_state_transition_in_same_context( opcode, - rw_counter=Transition.delta(instruction.rw_counter_offset), + rw_counter=Transition.delta(instruction.rw_counter_offset + copy_rwc_inc), program_counter=Transition.delta(1), stack_pointer=Transition.delta(3), memory_size=Transition.to(next_memory_size), diff --git a/src/zkevm_specs/evm/execution/codecopy.py b/src/zkevm_specs/evm/execution/codecopy.py index 6790c5d1a..60486c408 100644 --- a/src/zkevm_specs/evm/execution/codecopy.py +++ b/src/zkevm_specs/evm/execution/codecopy.py @@ -1,8 +1,7 @@ from ...util import N_BYTES_MEMORY_ADDRESS, FQ from ..execution_state import ExecutionState from ..instruction import Instruction, Transition -from ..step import CopyCodeToMemoryAuxData -from ..table import RW, RWTableTag, CallContextFieldTag, AccountFieldTag +from ..table import RW, RWTableTag, CallContextFieldTag, AccountFieldTag, CopyDataTypeTag def codecopy(instruction: Instruction): @@ -25,20 +24,23 @@ def codecopy(instruction: Instruction): gas_cost = instruction.memory_copier_gas_cost(size, memory_expansion_gas_cost) if instruction.is_zero(size) == FQ(0): - instruction.constrain_equal( - instruction.next.execution_state, ExecutionState.CopyCodeToMemory + copy_rwc_inc = instruction.copy_lookup( + instruction.curr.code_hash, + CopyDataTypeTag.Bytecode, + instruction.curr.call_id, + CopyDataTypeTag.Memory, + code_offset, + code_size, + memory_offset, + size, + instruction.curr.rw_counter + instruction.rw_counter_offset, ) - next_aux = instruction.next.aux_data - assert isinstance(next_aux, CopyCodeToMemoryAuxData) - instruction.constrain_equal(next_aux.src_addr, code_offset) - instruction.constrain_equal(next_aux.dst_addr, memory_offset) - instruction.constrain_equal(next_aux.src_addr_end, code_size) - instruction.constrain_equal(next_aux.bytes_left, size) - instruction.constrain_equal(next_aux.code_hash, instruction.curr.code_hash) + else: + copy_rwc_inc = FQ(0) instruction.step_state_transition_in_same_context( opcode, - rw_counter=Transition.delta(instruction.rw_counter_offset), + rw_counter=Transition.delta(instruction.rw_counter_offset + copy_rwc_inc), program_counter=Transition.delta(1), stack_pointer=Transition.delta(3), memory_size=Transition.to(next_memory_size), diff --git a/src/zkevm_specs/evm/execution/copy_code_to_memory.py b/src/zkevm_specs/evm/execution/copy_code_to_memory.py deleted file mode 100644 index d39da5ef9..000000000 --- a/src/zkevm_specs/evm/execution/copy_code_to_memory.py +++ /dev/null @@ -1,61 +0,0 @@ -import itertools -from typing import Iterator - -from ...util import FQ, MAX_N_BYTES_COPY_CODE_TO_MEMORY, N_BYTES_MEMORY_SIZE, RLC -from ..execution_state import ExecutionState -from ..instruction import Instruction, Transition -from ..step import CopyCodeToMemoryAuxData -from ..table import RW -from ..util import BufferReaderGadget - - -def copy_code_to_memory(instruction: Instruction): - aux = instruction.curr.aux_data - assert isinstance(aux, CopyCodeToMemoryAuxData) - - buffer_reader = BufferReaderGadget( - instruction, MAX_N_BYTES_COPY_CODE_TO_MEMORY, aux.src_addr, aux.src_addr_end, aux.bytes_left - ) - - for idx in range(MAX_N_BYTES_COPY_CODE_TO_MEMORY): - if buffer_reader.read_flag(idx) == 1: - byte = instruction.bytecode_lookup( - aux.code_hash, - aux.src_addr + idx, - ) - buffer_reader.constrain_byte(idx, byte) - - for idx in range(MAX_N_BYTES_COPY_CODE_TO_MEMORY): - if buffer_reader.has_data(idx) == 1: - byte = instruction.memory_lookup(RW.Write, aux.dst_addr + idx) - buffer_reader.constrain_byte(idx, byte) - - copied_bytes = buffer_reader.num_bytes() - lt, finished = instruction.compare(copied_bytes, aux.bytes_left, N_BYTES_MEMORY_SIZE) - - # either copied bytes are less than the bytes left, or copying is finished - instruction.constrain_zero((1 - lt) * (1 - finished)) - - if finished == 0: - instruction.constrain_equal( - instruction.next.execution_state, ExecutionState.CopyCodeToMemory - ) - next_aux = instruction.next.aux_data - assert next_aux is not None and isinstance(next_aux, CopyCodeToMemoryAuxData) - instruction.constrain_equal(next_aux.src_addr, aux.src_addr + copied_bytes) - instruction.constrain_equal(next_aux.dst_addr, aux.dst_addr + copied_bytes) - instruction.constrain_equal(next_aux.bytes_left + copied_bytes, aux.bytes_left) - instruction.constrain_equal(next_aux.src_addr_end, aux.src_addr_end) - instruction.constrain_equal(next_aux.code_hash, aux.code_hash) - - instruction.constrain_step_state_transition( - rw_counter=Transition.delta(instruction.rw_counter_offset), - call_id=Transition.same(), - is_root=Transition.same(), - is_create=Transition.same(), - code_hash=Transition.same(), - program_counter=Transition.same(), - stack_pointer=Transition.same(), - memory_size=Transition.same(), - reversible_write_counter=Transition.same(), - ) diff --git a/src/zkevm_specs/evm/execution/copy_to_log.py b/src/zkevm_specs/evm/execution/copy_to_log.py deleted file mode 100644 index b4aab4d3c..000000000 --- a/src/zkevm_specs/evm/execution/copy_to_log.py +++ /dev/null @@ -1,56 +0,0 @@ -from ...util import FQ, N_BYTES_MEMORY_SIZE -from ..execution_state import ExecutionState -from ..instruction import Instruction, Transition -from ..step import CopyToLogAuxData -from ..table import RW, TxLogFieldTag, CallContextFieldTag -from ..util import BufferReaderGadget -from ...util import MAX_COPY_BYTES - - -def copy_to_log(instruction: Instruction): - aux = instruction.curr.aux_data - assert isinstance(aux, CopyToLogAuxData) - - buffer_reader = BufferReaderGadget( - instruction, MAX_COPY_BYTES, aux.src_addr, aux.src_addr_end, aux.bytes_left - ) - - for i in range(MAX_COPY_BYTES): - if buffer_reader.read_flag(i) == 0: - byte = FQ.zero() - else: - byte = instruction.memory_lookup(RW.Read, aux.src_addr + i) - buffer_reader.constrain_byte(i, byte) - # when is_persistent = false, only do memory_lookup, no tx_log_lookup - if buffer_reader.has_data(i) == 1 and aux.is_persistent == 1: - instruction.constrain_equal( - byte, - instruction.tx_log_lookup( - aux.tx_id, - instruction.curr.log_id, - TxLogFieldTag.Data, - i + aux.data_start_index.n, - ), - ) - - copied_bytes = buffer_reader.num_bytes() - lt, finished = instruction.compare(copied_bytes, aux.bytes_left, N_BYTES_MEMORY_SIZE) - # constrain lt == 1 or finished == 1 - instruction.constrain_zero((1 - lt) * (1 - finished)) - - if finished == 0: - instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToLog) - next_aux = instruction.next.aux_data - assert isinstance(next_aux, CopyToLogAuxData) - instruction.constrain_equal(next_aux.src_addr, aux.src_addr + copied_bytes) - instruction.constrain_equal(next_aux.bytes_left + copied_bytes, aux.bytes_left) - instruction.constrain_equal(next_aux.src_addr_end, aux.src_addr_end) - instruction.constrain_equal(next_aux.is_persistent, aux.is_persistent) - instruction.constrain_equal(next_aux.tx_id, aux.tx_id) - instruction.constrain_equal( - next_aux.data_start_index, aux.data_start_index + MAX_COPY_BYTES - ) - - instruction.constrain_step_state_transition( - rw_counter=Transition.delta(instruction.rw_counter_offset), - ) diff --git a/src/zkevm_specs/evm/execution/log.py b/src/zkevm_specs/evm/execution/log.py index c68e47dbc..e6b9da75c 100644 --- a/src/zkevm_specs/evm/execution/log.py +++ b/src/zkevm_specs/evm/execution/log.py @@ -1,8 +1,8 @@ from ..instruction import Instruction, Transition -from ..table import CallContextFieldTag, TxLogFieldTag, TxContextFieldTag +from ..table import CallContextFieldTag, TxLogFieldTag, TxContextFieldTag, CopyDataTypeTag from ..opcode import Opcode from ..execution_state import ExecutionState -from ...util.param import GAS_COST_LOG +from ...util.param import GAS_COST_LOG, GAS_COST_LOGDATA from ...util import FQ, cast_expr @@ -62,17 +62,21 @@ def log(instruction: Instruction): diff = topic_selectors[i - 1] - topic_selectors[i] instruction.constrain_bool(FQ(diff)) - # check memory copy, should do in next step here - # When length != 0, constrain the state in the next execution state CopyToLog - if not instruction.is_zero(msize): - instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToLog) - next_aux = instruction.next.aux_data - instruction.constrain_equal(next_aux.src_addr, mstart) - instruction.constrain_equal(next_aux.src_addr_end, mstart + msize) - instruction.constrain_equal(next_aux.bytes_left, msize) - instruction.constrain_equal(next_aux.is_persistent, is_persistent) - instruction.constrain_equal(next_aux.tx_id, tx_id) - instruction.constrain_zero(next_aux.data_start_index) + if instruction.is_zero(msize) == 0 and is_persistent == 1: + copy_rwc_inc = instruction.copy_lookup( + instruction.curr.call_id, + CopyDataTypeTag.Memory, + tx_id, + CopyDataTypeTag.TxLog, + mstart, + mstart + msize, + FQ(0), + msize, + instruction.curr.rw_counter + instruction.rw_counter_offset, + log_id=instruction.curr.log_id + 1, + ) + else: + copy_rwc_inc = FQ(0) # omit block number constraint even it is set within op code explicitly, because by default the circuit only handle # current block, otherwise, block context lookup is required. @@ -81,13 +85,16 @@ def log(instruction: Instruction): mstart, msize ) dynamic_gas = ( - GAS_COST_LOG + GAS_COST_LOG * (opcode - Opcode.LOG0) + 8 * msize + memory_expansion_gas + GAS_COST_LOG + + GAS_COST_LOG * (opcode - Opcode.LOG0) + + GAS_COST_LOGDATA * msize + + memory_expansion_gas ) assert isinstance(is_persistent, FQ) instruction.step_state_transition_in_same_context( opcode, - rw_counter=Transition.delta(instruction.rw_counter_offset), + rw_counter=Transition.delta(instruction.rw_counter_offset + copy_rwc_inc), program_counter=Transition.delta(1), stack_pointer=Transition.delta(2 + opcode - Opcode.LOG0), dynamic_gas_cost=dynamic_gas, diff --git a/src/zkevm_specs/evm/execution/memory_copy.py b/src/zkevm_specs/evm/execution/memory_copy.py deleted file mode 100644 index 0157ee464..000000000 --- a/src/zkevm_specs/evm/execution/memory_copy.py +++ /dev/null @@ -1,56 +0,0 @@ -from ...util import MAX_N_BYTES_COPY_TO_MEMORY, N_BYTES_MEMORY_SIZE, FQ, Expression -from ..execution_state import ExecutionState -from ..instruction import Instruction, Transition -from ..step import CopyToMemoryAuxData -from ..table import RW -from ..util import BufferReaderGadget - - -def copy_to_memory(instruction: Instruction): - aux = instruction.curr.aux_data - assert isinstance(aux, CopyToMemoryAuxData) - - buffer_reader = BufferReaderGadget( - instruction, MAX_N_BYTES_COPY_TO_MEMORY, aux.src_addr, aux.src_addr_end, aux.bytes_left - ) - - for i in range(MAX_N_BYTES_COPY_TO_MEMORY): - if buffer_reader.read_flag(i) == 0: - byte: Expression = FQ(0) - elif aux.from_tx == 1: - byte = instruction.tx_calldata_lookup(aux.src_id, aux.src_addr + i) - else: - byte = instruction.memory_lookup(RW.Read, aux.src_addr + i, call_id=aux.src_id) - buffer_reader.constrain_byte(i, byte) - if buffer_reader.has_data(i) == 1: - instruction.constrain_equal(byte, instruction.memory_lookup(RW.Write, aux.dst_addr + i)) - - copied_bytes = buffer_reader.num_bytes() - lt, finished = instruction.compare(copied_bytes, aux.bytes_left, N_BYTES_MEMORY_SIZE) - # constrain lt == 1 or finished == 1 - instruction.constrain_zero((1 - lt) * (1 - finished)) - - if finished == 0: - next_aux = instruction.next.aux_data - - assert isinstance(next_aux, CopyToMemoryAuxData) - - instruction.constrain_equal(instruction.next.execution_state, ExecutionState.CopyToMemory) - instruction.constrain_equal(next_aux.src_addr, aux.src_addr + copied_bytes) - instruction.constrain_equal(next_aux.dst_addr, aux.dst_addr + copied_bytes) - instruction.constrain_equal(next_aux.bytes_left + copied_bytes, aux.bytes_left) - instruction.constrain_equal(next_aux.src_addr_end, aux.src_addr_end) - instruction.constrain_equal(next_aux.from_tx, aux.from_tx) - instruction.constrain_equal(next_aux.src_id, aux.src_id) - - instruction.constrain_step_state_transition( - rw_counter=Transition.delta(instruction.rw_counter_offset), - call_id=Transition.same(), - is_root=Transition.same(), - is_create=Transition.same(), - code_hash=Transition.same(), - program_counter=Transition.same(), - stack_pointer=Transition.same(), - memory_size=Transition.same(), - reversible_write_counter=Transition.same(), - ) diff --git a/src/zkevm_specs/evm/execution_state.py b/src/zkevm_specs/evm/execution_state.py index b2da834f6..b61d39600 100644 --- a/src/zkevm_specs/evm/execution_state.py +++ b/src/zkevm_specs/evm/execution_state.py @@ -19,9 +19,6 @@ class ExecutionState(IntEnum): BeginTx = auto() EndTx = auto() EndBlock = auto() - CopyToMemory = auto() - CopyToLog = auto() - CopyCodeToMemory = auto() # Opcode's successful cases STOP = auto() diff --git a/src/zkevm_specs/evm/instruction.py b/src/zkevm_specs/evm/instruction.py index fb251747f..01b45aab5 100644 --- a/src/zkevm_specs/evm/instruction.py +++ b/src/zkevm_specs/evm/instruction.py @@ -34,6 +34,7 @@ RWTableTag, TxLogFieldTag, TxReceiptFieldTag, + CopyDataTypeTag, ) @@ -154,8 +155,6 @@ def constrain_execution_state_transition(self): assert curr.halts() or curr == ExecutionState.BeginTx elif next == ExecutionState.EndBlock: assert curr in [ExecutionState.EndTx, ExecutionState.EndBlock] - elif next == ExecutionState.CopyToMemory: - assert curr in [ExecutionState.CopyToMemory, ExecutionState.CALLDATACOPY] def constrain_step_state_transition(self, **kwargs: Transition): keys = set( @@ -614,7 +613,7 @@ def tx_log_lookup( RW.Write, RWTableTag.TxLog, key1=tx_id, - key2=FQ(index + int(field_tag) << 32 + log_id.expr().n << 48), + key2=FQ(index + (int(field_tag) << 32) + (log_id.expr().n << 48)), key3=FQ(0), key4=FQ(0), ).value @@ -993,3 +992,29 @@ def memory_copier_gas_cost( def pow2_lookup(self, value: Expression, value_pow: Expression): self.fixed_lookup(FixedTableTag.Pow2, value, value_pow) + + def copy_lookup( + self, + src_id: Expression, + src_type: CopyDataTypeTag, + dst_id: Expression, + dst_type: CopyDataTypeTag, + src_addr: Expression, + src_addr_end: Expression, + dst_addr: Expression, + length: Expression, + rw_counter: Expression, + log_id: Expression = None, + ) -> FQ: + return self.tables.copy_lookup( + src_id, + FQ(src_type), + dst_id, + FQ(dst_type), + src_addr, + src_addr_end, + dst_addr, + length, + rw_counter, + log_id, + ).rwc_inc diff --git a/src/zkevm_specs/evm/step.py b/src/zkevm_specs/evm/step.py index da5038d84..8c62981e1 100644 --- a/src/zkevm_specs/evm/step.py +++ b/src/zkevm_specs/evm/step.py @@ -74,75 +74,3 @@ def __init__( self.reversible_write_counter = FQ(reversible_write_counter) self.log_id = FQ(log_id) self.aux_data = aux_data - - -class CopyToMemoryAuxData: - src_addr: FQ - dst_addr: FQ - bytes_left: FQ - src_addr_end: FQ - from_tx: FQ - src_id: FQ - - def __init__( - self, - src_addr: int, - dst_addr: int, - bytes_left: int, - src_addr_end: int, - from_tx: bool, - src_id: int, - ): - self.src_addr = FQ(src_addr) - self.dst_addr = FQ(dst_addr) - self.bytes_left = FQ(bytes_left) - self.src_addr_end = FQ(src_addr_end) - self.from_tx = FQ(from_tx) - self.src_id = FQ(src_id) - - -class CopyToLogAuxData: - src_addr: FQ - bytes_left: FQ - src_addr_end: FQ - is_persistent: FQ - tx_id: FQ - data_start_index: FQ - - def __init__( - self, - src_addr: int, - bytes_left: int, - src_addr_end: int, - is_persistent: int, - tx_id: int, - data_start_index: int, - ): - self.src_addr = FQ(src_addr) - self.bytes_left = FQ(bytes_left) - self.src_addr_end = FQ(src_addr_end) - self.is_persistent = FQ(is_persistent) - self.tx_id = FQ(tx_id) - self.data_start_index = FQ(data_start_index) - - -class CopyCodeToMemoryAuxData: - src_addr: FQ - dst_addr: FQ - bytes_left: FQ - src_addr_end: FQ - code_hash: RLC - - def __init__( - self, - src_addr: int, - dst_addr: int, - bytes_left: int, - src_addr_end: int, - code_hash: RLC, - ): - self.src_addr = FQ(src_addr) - self.dst_addr = FQ(dst_addr) - self.bytes_left = FQ(bytes_left) - self.src_addr_end = FQ(src_addr_end) - self.code_hash = code_hash diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 43623647b..9bf660be8 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -279,6 +279,17 @@ class TxReceiptFieldTag(IntEnum): LogLength = auto() +class CopyDataTypeTag(IntEnum): + """ + Tag for CopyTable that specifies the type of data source. + """ + + Bytecode = auto() + Memory = auto() + TxCalldata = auto() + TxLog = auto() + + class WrongQueryKey(Exception): def __init__(self, table_name: str, diff: Set[str]) -> None: self.message = f"Lookup {table_name} with invalid keys {diff}" @@ -366,6 +377,41 @@ class MPTTableRow(TableRow): value_prev: Expression +@dataclass +class CopyCircuitRow(TableRow): + q_step: FQ + is_first: FQ + is_last: FQ + id: FQ # one of call_id, bytecode_hash, tx_id + tag: FQ # CopyDataTypeTag + addr: FQ + src_addr_end: FQ + bytes_left: FQ + value: FQ + is_code: FQ + is_pad: FQ + rw_counter: FQ + rwc_inc_left: FQ + is_memory: FQ + is_bytecode: FQ + is_tx_calldata: FQ + is_tx_log: FQ + + +@dataclass(frozen=True) +class CopyTableRow(TableRow): + src_id: FQ + src_type: FQ + dst_id: FQ + dst_type: FQ + src_addr: FQ + src_addr_end: FQ + dst_addr: FQ + length: FQ + rw_counter: FQ + rwc_inc: FQ + + class Tables: """ A collection of lookup tables used in EVM circuit. @@ -376,6 +422,7 @@ class Tables: tx_table: Set[TxTableRow] bytecode_table: Set[BytecodeTableRow] rw_table: Set[RWTableRow] + copy_table: Set[CopyTableRow] def __init__( self, @@ -383,6 +430,7 @@ def __init__( tx_table: Set[TxTableRow], bytecode_table: Set[BytecodeTableRow], rw_table: Union[Set[Sequence[Expression]], Set[RWTableRow]], + copy_circuit: Sequence[CopyCircuitRow] = None, ) -> None: self.block_table = block_table self.tx_table = tx_table @@ -391,6 +439,32 @@ def __init__( row if isinstance(row, RWTableRow) else RWTableRow(*row) # type: ignore # (RWTableRow input args) for row in rw_table ) + if copy_circuit is not None: + self.copy_table = self._convert_copy_circuit_to_table(copy_circuit) + + def _convert_copy_circuit_to_table(self, copy_circuit: Sequence[CopyCircuitRow]): + rows = [] + for i, row in enumerate(copy_circuit): + if row.is_first != 1: + continue + assert i + 1 < len(copy_circuit), "Not enough rows in copy circuit" + next_row = copy_circuit[i + 1] + assert next_row.q_step == 0, "Invalid copy circuit" + rows.append( + CopyTableRow( + src_id=row.id, + src_type=row.tag, + dst_id=next_row.id, + dst_type=next_row.tag, + src_addr=row.addr, + src_addr_end=row.src_addr_end, + dst_addr=next_row.addr, + length=row.bytes_left, + rw_counter=row.rw_counter, + rwc_inc=row.rwc_inc_left, + ) + ) + return set(rows) def fixed_lookup( self, @@ -468,6 +542,35 @@ def rw_lookup( } return lookup(RWTableRow, self.rw_table, query) + def copy_lookup( + self, + src_id: Expression, + src_type: Expression, + dst_id: Expression, + dst_type: Expression, + src_addr: Expression, + src_addr_end: Expression, + dst_addr: Expression, + length: Expression, + rw_counter: Expression, + log_id: Expression = None, + ) -> CopyTableRow: + if dst_type == CopyDataTypeTag.TxLog: + assert log_id is not None + dst_addr = dst_addr + FQ(int(TxLogFieldTag.Data) << 32) + FQ(log_id.expr().n << 48) + query = { + "src_id": src_id, + "src_type": src_type, + "dst_id": dst_id, + "dst_type": dst_type, + "src_addr": src_addr, + "src_addr_end": src_addr_end, + "dst_addr": dst_addr, + "length": length, + "rw_counter": rw_counter, + } + return lookup(CopyTableRow, self.copy_table, query) + T = TypeVar("T", bound=TableRow) diff --git a/src/zkevm_specs/evm/typing.py b/src/zkevm_specs/evm/typing.py index c3817eec0..6da1cce26 100644 --- a/src/zkevm_specs/evm/typing.py +++ b/src/zkevm_specs/evm/typing.py @@ -1,5 +1,17 @@ from __future__ import annotations -from typing import Dict, Iterator, List, NewType, Optional, Sequence, Union +from typing import ( + cast, + Dict, + Iterator, + List, + MutableSequence, + NewType, + Optional, + Sequence, + Union, + Mapping, + Tuple, +) from functools import reduce from itertools import chain @@ -30,6 +42,9 @@ TxLogFieldTag, TxReceiptFieldTag, TxTableRow, + CopyDataTypeTag, + CopyCircuitRow, + CopyTableRow, ) from .opcode import get_push_size, Opcode @@ -197,11 +212,25 @@ def table_assignments(self, randomness: FQ) -> Iterator[TxTableRow]: ) +def init_is_code(code: bytearray) -> MutableSequence[bool]: + is_codes = [] + push_data_left = 0 + for idx in range(0, len(code)): + is_code = push_data_left == 0 + push_data_left = get_push_size(code[idx]) if is_code else push_data_left - 1 + is_codes.append(is_code) + return is_codes + + class Bytecode: code: bytearray + is_code: MutableSequence[bool] - def __init__(self, code: Optional[bytearray] = None) -> None: + def __init__( + self, code: Optional[bytearray] = None, is_code: Optional[MutableSequence[bool]] = None + ) -> None: self.code = bytearray() if code is None else code + self.is_code = init_is_code(self.code) if is_code is None else is_code def __getattr__(self, name: str): def method(*args) -> Bytecode: @@ -216,11 +245,13 @@ def method(*args) -> Bytecode: elif opcode.is_dup() or opcode.is_swap(): assert len(args) == 0 self.code.append(opcode) + self.is_code.append(True) else: assert len(args) <= 1024 - opcode.max_stack_pointer() for arg in reversed(args): self.push(arg, 32) self.code.append(opcode) + self.is_code.append(True) return self @@ -242,7 +273,9 @@ def push(self, value: Union[int, str, bytes, bytearray, RLC], n_bytes: int = 32) opcode = Opcode.PUSH1 + n_bytes - 1 self.code.append(opcode) + self.is_code.append(True) self.code.extend(value.rjust(n_bytes, b"\x00")) + self.is_code.extend([False] * n_bytes) return self @@ -252,15 +285,16 @@ def hash(self) -> U256: def table_assignments(self, randomness: FQ) -> Iterator[BytecodeTableRow]: class BytecodeIterator: idx: int - push_data_left: int hash: FQ code: bytes + is_code: Sequence[bool] - def __init__(self, hash: FQ, code: bytes): + def __init__(self, hash: FQ, code: bytes, is_code: Sequence[bool]): self.idx = 0 - self.push_data_left = 0 self.hash = hash self.code = code + self.is_code = is_code + assert len(code) == len(is_code) def __iter__(self): return self @@ -279,14 +313,13 @@ def __next__(self): # the other rows represent each byte in the bytecode idx = self.idx - 1 byte = self.code[idx] - is_code = self.push_data_left == 0 - self.push_data_left = get_push_size(byte) if is_code else self.push_data_left - 1 + is_code = self.is_code[idx] self.idx += 1 return BytecodeTableRow( self.hash, FQ(BytecodeFieldTag.Byte), FQ(idx), FQ(is_code), FQ(byte) ) - return BytecodeIterator(RLC(self.hash(), randomness).expr(), self.code) + return BytecodeIterator(RLC(self.hash(), randomness).expr(), self.code, self.is_code) Storage = NewType("Storage", Dict[U256, U256]) @@ -376,7 +409,7 @@ def tx_log_write( tx_id: IntOrFQ, log_id: int, field_tag: TxLogFieldTag, - index: int, + index: IntOrFQ, value: Union[int, FQ, RLC], ) -> RWDictionary: if isinstance(value, int): @@ -385,7 +418,7 @@ def tx_log_write( RW.Write, RWTableTag.TxLog, key1=FQ(tx_id), - key2=FQ(index + int(field_tag) << 32 + log_id << 48), + key2=FQ(index + (int(field_tag) << 32) + (log_id << 48)), key3=FQ(0), key4=FQ(0), value=value, @@ -615,3 +648,141 @@ def _append( ) return self + + +class CopyCircuit: + rows: List[CopyCircuitRow] + pad_rows: List[CopyCircuitRow] + + def __init__(self) -> None: + self.rows = [] + self.pad_rows = [CopyCircuitRow(FQ(1), *[FQ(0)] * 16), CopyCircuitRow(*[FQ(0)] * 17)] + + def table(self) -> Sequence[CopyCircuitRow]: + return self.rows + self.pad_rows + + def copy( + self, + rw_dict: RWDictionary, + src_id: IntOrFQ, + src_type: CopyDataTypeTag, + dst_id: IntOrFQ, + dst_type: CopyDataTypeTag, + src_addr: IntOrFQ, + src_addr_end: IntOrFQ, + dst_addr: IntOrFQ, + copy_length: IntOrFQ, + src_data: Mapping[IntOrFQ, Union[IntOrFQ, Tuple[IntOrFQ, IntOrFQ]]], + log_id: int = 0, + ): + new_rows: List[CopyCircuitRow] = [] + for i in range(int(copy_length)): + if int(src_addr + i) < int(src_addr_end): + is_pad = False + assert src_addr + i in src_data, f"Cannot find data at the offset {src_addr+i}" + value = src_data[src_addr + i] + if src_type == CopyDataTypeTag.Bytecode: + value = cast(Tuple[IntOrFQ, IntOrFQ], value) + value, is_code = value + else: + value = cast(IntOrFQ, value) + is_code = FQ(0) + value = FQ(value) + is_code = FQ(is_code) + else: + is_pad = True + value = FQ(0) + is_code = FQ(0) + + # read row, because TxLog is write-only, no need to feed log_id in the read row + self._append_row( + new_rows, + rw_dict, + False, + i == 0, + False, + src_id, + src_type, + src_addr + i, + value, + is_code, + is_pad, + src_addr_end=src_addr_end, + bytes_left=copy_length - i, + ) + + # write row + self._append_row( + new_rows, + rw_dict, + True, + False, + i == copy_length - 1, + dst_id, + dst_type, + dst_addr + i, + value, + is_code, + False, + log_id=log_id, + ) + + # update the rwc_inc_left column + rw_counter = rw_dict.rw_counter + for row in new_rows: + row.rwc_inc_left = rw_counter - row.rw_counter + self.rows.extend(new_rows) + return self + + def _append_row( + self, + rows: MutableSequence[CopyCircuitRow], + rw_dict: RWDictionary, + is_write: bool, + is_first: bool, + is_last: bool, + id: IntOrFQ, + tag: CopyDataTypeTag, + addr: IntOrFQ, + value: IntOrFQ, + is_code: IntOrFQ, + is_pad: bool, + src_addr_end: IntOrFQ = FQ(0), + bytes_left: IntOrFQ = FQ(0), + log_id: int = 0, + ): + is_memory = tag == CopyDataTypeTag.Memory + is_bytecode = tag == CopyDataTypeTag.Bytecode + is_tx_calldata = tag == CopyDataTypeTag.TxCalldata + is_tx_log = tag == CopyDataTypeTag.TxLog + rw_counter = rw_dict.rw_counter + if is_memory: + if is_write: + rw_dict.memory_write(id, addr, value) + else: + rw_dict.memory_read(id, addr, value) + elif is_tx_log: + assert is_write + rw_dict.tx_log_write(id, log_id, TxLogFieldTag.Data, addr, value) + addr += (int(TxLogFieldTag.Data) << 32) + (log_id << 48) + rows.append( + CopyCircuitRow( + q_step=FQ(not is_write), + is_first=FQ(is_first), + is_last=FQ(is_last), + id=FQ(id), + tag=FQ(tag), + addr=FQ(addr), + src_addr_end=FQ(src_addr_end), + bytes_left=FQ(bytes_left), + value=FQ(value), + is_code=FQ(is_code), + is_pad=FQ(is_pad), + rw_counter=FQ(rw_counter), + rwc_inc_left=FQ(0), # placeholder for now + is_memory=FQ(is_memory), + is_bytecode=FQ(is_bytecode), + is_tx_calldata=FQ(is_tx_calldata), + is_tx_log=FQ(is_tx_log), + ) + ) diff --git a/src/zkevm_specs/util/__init__.py b/src/zkevm_specs/util/__init__.py index 5f6e8aac5..5f70cb7eb 100644 --- a/src/zkevm_specs/util/__init__.py +++ b/src/zkevm_specs/util/__init__.py @@ -3,6 +3,7 @@ from Crypto.Random.random import randrange from .arithmetic import * +from .constraint_system import * from .hash import * from .param import * from .typing import * diff --git a/src/zkevm_specs/util/constraint_system.py b/src/zkevm_specs/util/constraint_system.py new file mode 100644 index 000000000..cc764641d --- /dev/null +++ b/src/zkevm_specs/util/constraint_system.py @@ -0,0 +1,50 @@ +from typing import Optional + +from .arithmetic import Expression, FQ + + +class ConstraintUnsatFailure(Exception): + def __init__(self, message: str) -> None: + self.message = message + + +class ConstraintSystem: + cond: Optional[Expression] + + def __init__(self, cond: Optional[Expression] = None): + self.cond = cond + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.cond = None + return self + + def _eval(self, expr: Expression): + if self.cond: + return self.cond.expr() * expr.expr() + return expr.expr() + + def constrain_equal(self, lhs: Expression, rhs: Expression): + assert self._eval(lhs.expr() - rhs.expr()) == 0, ConstraintUnsatFailure( + f"Expected values to be equal, but got {lhs} and {rhs}" + ) + + def constrain_zero(self, value: Expression): + assert self._eval(value) == 0, ConstraintUnsatFailure( + f"Expected value to be 0, but got {value}" + ) + + def constrain_bool(self, value: Expression): + assert self._eval(value) in [0, 1], ConstraintUnsatFailure( + f"Expected value to be a bool, but got {value}" + ) + + def is_zero(self, value: Expression) -> FQ: + return FQ(value.expr() == 0) + + def condition(self, cond: Expression): + assert self.cond is None, "Don't support recursive conditions" + self.cond = cond + return self diff --git a/src/zkevm_specs/util/param.py b/src/zkevm_specs/util/param.py index c273d061b..9dc37889f 100644 --- a/src/zkevm_specs/util/param.py +++ b/src/zkevm_specs/util/param.py @@ -45,6 +45,8 @@ GAS_COST_TX = 21000 # Constant gas cost of LOG GAS_COST_LOG = 375 +# Gas cost of per byte in a LOG* operation's data. +GAS_COST_LOGDATA = 8 # Gas cost of creation transaction GAS_COST_CREATION_TX = 53000 # Gas cost of transaction call_data per non-zero byte diff --git a/tests/evm/test_calldatacopy.py b/tests/evm/test_calldatacopy.py index 96e1025f4..a9284399d 100644 --- a/tests/evm/test_calldatacopy.py +++ b/tests/evm/test_calldatacopy.py @@ -5,7 +5,6 @@ Opcode, ExecutionState, StepState, - CopyToMemoryAuxData, verify_steps, Tables, RWTableTag, @@ -16,7 +15,10 @@ Transaction, Bytecode, RWDictionary, + CopyCircuit, + CopyDataTypeTag, ) +from zkevm_specs.copy_circuit import verify_copy_table from zkevm_specs.util import ( rand_fq, rand_bytes, @@ -24,6 +26,8 @@ MAX_N_BYTES_COPY_TO_MEMORY, MEMORY_EXPANSION_QUAD_DENOMINATOR, MEMORY_EXPANSION_LINEAR_COEFF, + memory_word_size, + memory_expansion, ) @@ -47,109 +51,6 @@ ) -def to_word_size(addr: int) -> int: - return (addr + 31) // 32 - - -def make_copy_step( - buffer_map: Mapping[int, int], - src_addr: int, - dst_addr: int, - src_addr_end: int, - bytes_left: int, - from_tx: bool, - rw_dictionary: RWDictionary, - program_counter: int, - stack_pointer: int, - memory_size: int, - gas_left: int, - code_hash: RLC, -) -> StepState: - aux_data = CopyToMemoryAuxData( - src_addr=src_addr, - dst_addr=dst_addr, - src_addr_end=src_addr_end, - bytes_left=bytes_left, - from_tx=from_tx, - src_id=TX_ID if from_tx else CALLER_ID, - ) - step = StepState( - execution_state=ExecutionState.CopyToMemory, - rw_counter=rw_dictionary.rw_counter, - call_id=1, - is_root=from_tx, - program_counter=program_counter, - stack_pointer=stack_pointer, - gas_left=gas_left, - memory_size=memory_size, - code_hash=code_hash, - aux_data=aux_data, - ) - - num_bytes = min(MAX_N_BYTES_COPY_TO_MEMORY, bytes_left) - for i in range(num_bytes): - byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 - if not from_tx and src_addr + i < src_addr_end: - rw_dictionary.memory_read(CALLER_ID, src_addr + i, byte) - rw_dictionary.memory_write(CALL_ID, dst_addr + i, byte) - - return step - - -def make_copy_steps( - buffer: bytes, - buffer_addr: int, - src_addr: int, - dst_addr: int, - length: int, - from_tx: bool, - rw_dictionary: RWDictionary, - program_counter: int, - stack_pointer: int, - memory_size: int, - gas_left: int, - code_hash: RLC, -) -> Sequence[StepState]: - buffer_addr_end = buffer_addr + len(buffer) - buffer_map = dict(zip(range(buffer_addr, buffer_addr_end), buffer)) - steps = [] - bytes_left = length - while bytes_left > 0: - new_step = make_copy_step( - buffer_map, - src_addr, - dst_addr, - buffer_addr_end, - bytes_left, - from_tx, - rw_dictionary, - program_counter, - stack_pointer, - memory_size, - gas_left, - code_hash, - ) - steps.append(new_step) - src_addr += MAX_N_BYTES_COPY_TO_MEMORY - dst_addr += MAX_N_BYTES_COPY_TO_MEMORY - bytes_left -= MAX_N_BYTES_COPY_TO_MEMORY - return steps - - -def memory_gas_cost(memory_word_size: int) -> int: - quad_cost = memory_word_size * memory_word_size // MEMORY_EXPANSION_QUAD_DENOMINATOR - linear_cost = memory_word_size * MEMORY_EXPANSION_LINEAR_COEFF - return quad_cost + linear_cost - - -def memory_copier_gas_cost( - curr_memory_word_size: int, next_memory_word_size: int, length: int -) -> int: - curr_memory_cost = memory_gas_cost(curr_memory_word_size) - next_memory_cost = memory_gas_cost(next_memory_word_size) - return to_word_size(length) * GAS_COST_COPY + next_memory_cost - curr_memory_cost - - @pytest.mark.parametrize( "call_data_length, data_offset, memory_offset, length, from_tx, call_data_offset", TESTING_DATA ) @@ -171,13 +72,15 @@ def test_calldatacopy( length_rlc = RLC(length, randomness) call_data = rand_bytes(call_data_length) - curr_memory_word_size = to_word_size(0 if from_tx else call_data_offset + call_data_length) - if length == 0: - next_memory_word_size = curr_memory_word_size - else: - next_memory_word_size = max(curr_memory_word_size, to_word_size(memory_offset + length)) - gas = Opcode.CALLDATACOPY.constant_gas_cost() + memory_copier_gas_cost( - curr_memory_word_size, next_memory_word_size, length + curr_mem_size = memory_word_size(0 if from_tx else call_data_offset + call_data_length) + address = 0 if length == 0 else memory_offset + length + next_mem_size, memory_gas_cost = memory_expansion( + curr_mem_size, memory_offset + length if length else 0 + ) + gas = ( + Opcode.CALLDATACOPY.constant_gas_cost() + + memory_gas_cost + + memory_word_size(length) * GAS_COST_COPY ) if from_tx: @@ -196,11 +99,18 @@ def test_calldatacopy( code_hash=bytecode_hash, program_counter=99, stack_pointer=1021, - memory_size=curr_memory_word_size, + memory_size=curr_mem_size, gas_left=gas, ) ] + rw_dictionary = ( + RWDictionary(1) + .stack_read(CALL_ID, 1021, memory_offset_rlc) + .stack_read(CALL_ID, 1022, data_offset_rlc) + .stack_read(CALL_ID, 1023, length_rlc) + .call_context_read(CALL_ID, CallContextFieldTag.TxId, TX_ID) + ) rw_dictionary = ( RWDictionary(1) .stack_read(CALL_ID, 1021, memory_offset_rlc) @@ -211,6 +121,8 @@ def test_calldatacopy( rw_dictionary.call_context_read(CALL_ID, CallContextFieldTag.TxId, TX_ID).call_context_read( CALL_ID, CallContextFieldTag.CallDataLength, call_data_length ) + src_data = dict(zip(range(call_data_length), call_data)) + assert call_data_offset == 0 else: rw_dictionary.call_context_read( CALL_ID, CallContextFieldTag.CallerId, CALLER_ID @@ -220,21 +132,24 @@ def test_calldatacopy( CALL_ID, CallContextFieldTag.CallDataOffset, call_data_offset ) - new_steps = make_copy_steps( - call_data, - call_data_offset, - call_data_offset + data_offset, + src_data = dict( + [ + (call_data_offset + i, call_data[i]) + for i in range(data_offset, min(data_offset + length, len(call_data))) + ] + ) + copy_circuit = CopyCircuit().copy( + rw_dictionary, + TX_ID if from_tx else CALLER_ID, + CopyDataTypeTag.TxCalldata if from_tx else CopyDataTypeTag.Memory, + CALL_ID, + CopyDataTypeTag.Memory, + data_offset + call_data_offset, + call_data_length + call_data_offset, memory_offset, length, - from_tx, - rw_dictionary=rw_dictionary, - program_counter=100, - memory_size=next_memory_word_size, - stack_pointer=1024, - gas_left=0, - code_hash=bytecode_hash, + src_data, ) - steps.extend(new_steps) steps.append( StepState( @@ -246,7 +161,7 @@ def test_calldatacopy( code_hash=bytecode_hash, program_counter=100, stack_pointer=1024, - memory_size=next_memory_word_size, + memory_size=next_mem_size, gas_left=0, ) ) @@ -256,8 +171,10 @@ def test_calldatacopy( tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set(rw_dictionary.rws), + copy_circuit=copy_circuit.rows, ) + verify_copy_table(copy_circuit, tables) verify_steps( randomness=randomness, tables=tables, diff --git a/tests/evm/test_codecopy.py b/tests/evm/test_codecopy.py index 5713df58f..48185c190 100644 --- a/tests/evm/test_codecopy.py +++ b/tests/evm/test_codecopy.py @@ -6,7 +6,6 @@ AccountFieldTag, Bytecode, CallContextFieldTag, - CopyCodeToMemoryAuxData, ExecutionState, Opcode, RW, @@ -15,7 +14,10 @@ StepState, Tables, verify_steps, + CopyCircuit, + CopyDataTypeTag, ) +from zkevm_specs.copy_circuit import verify_copy_table from zkevm_specs.util import ( GAS_COST_COPY, FQ, @@ -58,84 +60,6 @@ def memory_copier_gas_cost( return to_word_size(length) * GAS_COST_COPY + next_memory_cost - curr_memory_cost -def make_copy_code_step( - code: Bytecode, - code_hash: RLC, - buffer_map: Mapping[int, int], - src_addr: int, - dst_addr: int, - src_addr_end: int, - bytes_left: int, - rw_dictionary: RWDictionary, - program_counter: int, - stack_pointer: int, - memory_size: int, - randomness: FQ, -) -> StepState: - aux_data = CopyCodeToMemoryAuxData( - src_addr=src_addr, - dst_addr=dst_addr, - src_addr_end=src_addr_end, - bytes_left=bytes_left, - code_hash=RLC(code.hash(), randomness), - ) - step = StepState( - execution_state=ExecutionState.CopyCodeToMemory, - rw_counter=rw_dictionary.rw_counter, - call_id=CALL_ID, - is_root=True, - program_counter=program_counter, - stack_pointer=stack_pointer, - gas_left=0, - memory_size=memory_size, - code_hash=code_hash, - aux_data=aux_data, - ) - - num_bytes = min(MAX_N_BYTES_COPY_CODE_TO_MEMORY, bytes_left) - for i in range(num_bytes): - byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 - rw_dictionary.memory_write(CALL_ID, dst_addr + i, byte) - return step - - -def make_copy_code_steps( - code: Bytecode, - code_hash: RLC, - src_addr: int, - dst_addr: int, - length: int, - rw_dictionary: RWDictionary, - program_counter: int, - stack_pointer: int, - memory_size: int, - randomness: FQ, -) -> Sequence[StepState]: - buffer_map = dict(zip(range(src_addr, len(code.code)), code.code)) - steps = [] - bytes_left = length - while bytes_left > 0: - new_step = make_copy_code_step( - code, - code_hash, - buffer_map, - src_addr, - dst_addr, - len(code.code), - bytes_left, - rw_dictionary, - program_counter, - stack_pointer, - memory_size, - randomness, - ) - steps.append(new_step) - src_addr += MAX_N_BYTES_COPY_CODE_TO_MEMORY - dst_addr += MAX_N_BYTES_COPY_CODE_TO_MEMORY - bytes_left -= MAX_N_BYTES_COPY_CODE_TO_MEMORY - return steps - - @pytest.mark.parametrize("src_addr, dst_addr, length", TESTING_DATA) def test_codecopy(src_addr: U64, dst_addr: U64, length: U64): randomness = rand_fq() @@ -210,19 +134,19 @@ def test_codecopy(src_addr: U64, dst_addr: U64, length: U64): ), ] - steps_internal = make_copy_code_steps( - code, - code_hash, + src_data = dict([(i, (code.code[i], code.is_code[i])) for i in range(len(code.code))]) + copy_circuit = CopyCircuit().copy( + rw_dictionary, + code_hash.rlc_value, + CopyDataTypeTag.Bytecode, + CALL_ID, + CopyDataTypeTag.Memory, src_addr, + len(code.code), dst_addr, length, - rw_dictionary=rw_dictionary, - program_counter=100, - stack_pointer=1024, - memory_size=next_memory_word_size, - randomness=randomness, + src_data, ) - steps.extend(steps_internal) # rw counter post memory writes rw_counter_final = rw_dictionary.rw_counter @@ -247,74 +171,10 @@ def test_codecopy(src_addr: U64, dst_addr: U64, length: U64): tx_table=set(), bytecode_table=set(code.table_assignments(randomness)), rw_table=set(rw_dictionary.rws), + copy_circuit=copy_circuit.rows, ) - verify_steps( - randomness=randomness, - tables=tables, - steps=steps, - ) - - -@pytest.mark.parametrize("src_addr, dst_addr, length", TESTING_DATA) -def test_copy_code_to_memory(src_addr: U64, dst_addr: U64, length: U64): - randomness = rand_fq() - - code = ( - Bytecode() - .push32(0x123) - .pop() - .push32(0x213) - .pop() - .push32(0x321) - .pop() - .push32(0x12349AB) - .pop() - .push32(0x1928835) - .pop() - ) - - dummy_code = Bytecode().stop() - code_hash = RLC(dummy_code.hash(), randomness) - - rw_dictionary = RWDictionary(1) - - next_memory_word_size = to_word_size(dst_addr + length) - steps = make_copy_code_steps( - code, - code_hash, - src_addr, - dst_addr, - length, - rw_dictionary=rw_dictionary, - program_counter=0, - memory_size=next_memory_word_size, - stack_pointer=1024, - randomness=randomness, - ) - steps.append( - StepState( - execution_state=ExecutionState.STOP, - rw_counter=rw_dictionary.rw_counter, - call_id=CALL_ID, - is_root=True, - is_create=False, - code_hash=code_hash, - program_counter=0, - stack_pointer=1024, - memory_size=next_memory_word_size, - gas_left=0, - ) - ) - - tables = Tables( - block_table=set(), - tx_table=set(), - bytecode_table=set(code.table_assignments(randomness)).union( - dummy_code.table_assignments(randomness) - ), - rw_table=set(rw_dictionary.rws), - ) + verify_copy_table(copy_circuit, tables) verify_steps( randomness=randomness, diff --git a/tests/evm/test_logs.py b/tests/evm/test_logs.py index fc464c723..e5d7daaf8 100644 --- a/tests/evm/test_logs.py +++ b/tests/evm/test_logs.py @@ -4,7 +4,6 @@ Opcode, ExecutionState, StepState, - CopyToLogAuxData, verify_steps, Tables, RWTableTag, @@ -16,9 +15,12 @@ Transaction, Bytecode, GAS_COST_LOG, + GAS_COST_LOGDATA, RWDictionary, + CopyCircuit, + CopyDataTypeTag, ) -from zkevm_specs.evm.execution.copy_to_log import MAX_COPY_BYTES +from zkevm_specs.copy_circuit import verify_copy_table from zkevm_specs.util import ( rand_fq, rand_bytes, @@ -32,139 +34,150 @@ CALL_ID = 1 TX_ID = 2 CALLEE_ADDRESS = rand_address() -bytecodes = [ - Bytecode().log0(), - Bytecode().log1(), - Bytecode().log2(), - Bytecode().log3(), - Bytecode().log4(), -] -TESTING_DATA = ( + +SINGLE_LOG_TESTING_DATA = ( # is_persistent = true cases # zero topic(log0) - ([], 10, 2, 1), + ([], 10, 2, True), # one topic(log1) - ([0x030201], 20, 3, 1), + ([0x030201], 20, 3, True), # two topics(log2) - ([0x030201, 0x0F0E0D], 100, 20, 1), + ([0x030201, 0x0F0E0D], 100, 20, True), # three topics(log3) - ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, 1), + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), # four topics(log4) - ([0x030201, 0x0F0E0D, 0x0D8F01, 0x0AA213], 421, 15, 1), + ([0x030201, 0x0F0E0D, 0x0D8F01, 0x0AA213], 421, 15, True), # is_persistent = false cases # zero topic(log0) - ([], 10, 2, 0), + ([], 10, 2, False), # one topic(log1) - ([0x030201], 20, 3, 0), + ([0x030201], 20, 3, False), # two topics(log2) - ([0x030201, 0x0F0E0D], 100, 20, 0), + ([0x030201, 0x0F0E0D], 100, 20, False), # three topics(log3) - ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, 0), + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, False), # four topics(log4) - ([0x030201, 0x0F0E0D, 0x0D8F01, 0x0AA213], 421, 15, 0), + ([0x030201, 0x0F0E0D, 0x0D8F01, 0x0AA213], 421, 15, False), +) + +MULTI_LOGS_TESTING_DATA = ( + ( + ([], 10, 2, True), + ([0x030201, 0x0F0E0D], 100, 20, True), + ), + ( + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), + ([0x030201], 20, 3, False), + ), + ( + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), + ([0x030201], 20, 3, False), + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), + ), + ( + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), + ([0x030201], 20, 3, True), + ([0x030201, 0x0F0E0D, 0x0D8F01], 180, 50, True), + ), ) -def make_log_copy_step( - buffer_map: Mapping[int, int], - src_addr: int, - src_addr_end: int, - bytes_left: int, - data_start_index: int, +def log_code(bytecode: Bytecode, num_topic: int): + if num_topic == 0: + bytecode.log0() + elif num_topic == 1: + bytecode.log1() + elif num_topic == 2: + bytecode.log2() + elif num_topic == 3: + bytecode.log3() + elif num_topic == 4: + bytecode.log4() + else: + raise ValueError(f"Incorrect number of topics: {num_topic}") + + +# helper to construct topics rows of RW table +def construct_topic_rws( rw_dictionary: RWDictionary, - program_counter: int, - stack_pointer: int, - memory_size: int, - gas_left: int, - code_hash: RLC, log_id: int, + sp: int, + topics: list, is_persistent: bool, -) -> Tuple[StepState, Sequence[RW]]: - aux_data = CopyToLogAuxData( - src_addr=src_addr, - src_addr_end=src_addr_end, - bytes_left=bytes_left, - is_persistent=is_persistent, - tx_id=TX_ID, - data_start_index=data_start_index, - ) - step = StepState( - execution_state=ExecutionState.CopyToLog, - rw_counter=rw_dictionary.rw_counter, - call_id=CALL_ID, - program_counter=program_counter, - stack_pointer=stack_pointer, - gas_left=gas_left, - memory_size=memory_size, - code_hash=code_hash, - log_id=is_persistent, - aux_data=aux_data, - ) - num_bytes = min(MAX_COPY_BYTES, bytes_left) - for i in range(num_bytes): - byte = buffer_map[src_addr + i] if src_addr + i < src_addr_end else 0 - if src_addr + i < src_addr_end: - rw_dictionary.memory_read(CALL_ID, src_addr + i, FQ(byte)) - if is_persistent: - rw_dictionary.tx_log_write( - TX_ID, log_id, TxLogFieldTag.Data, i + data_start_index, FQ(byte) - ) + randomness: int, +): + for i in range(len(topics)): + rw_dictionary.stack_read(CALL_ID, sp, RLC(topics[i], randomness, 32)) + if is_persistent: + rw_dictionary.tx_log_write( + TX_ID, log_id, TxLogFieldTag.Topic, i, RLC(topics[i], randomness, 32) + ) - return step + sp += 1 -def make_log_copy_steps( - buffer: bytes, - buffer_addr: int, - src_addr: int, - length: int, +def make_log( rw_dictionary: RWDictionary, - program_counter: int, + copy_circuit: CopyCircuit, + randomness: FQ, stack_pointer: int, - memory_size: int, - gas_left: int, - code_hash: RLC, log_id: int, + topics: list, + mstart: U64, + msize: U64, is_persistent: bool, -) -> Sequence[StepState]: - buffer_addr_end = buffer_addr + len(buffer) - buffer_map = dict(zip(range(buffer_addr, buffer_addr_end), buffer)) - steps = [] - bytes_left = length - data_start_index = 0 - while bytes_left > 0: - new_step = make_log_copy_step( - buffer_map, - src_addr, - buffer_addr_end, - bytes_left, - data_start_index, +): + data = rand_bytes(msize) + ( + rw_dictionary.stack_read(CALL_ID, stack_pointer, RLC(mstart, randomness)) + .stack_read(CALL_ID, stack_pointer + 1, RLC(msize, randomness)) + .call_context_read(CALL_ID, CallContextFieldTag.TxId, TX_ID) + .call_context_read(CALL_ID, CallContextFieldTag.IsStatic, 0) + .call_context_read(CALL_ID, CallContextFieldTag.CalleeAddress, FQ(CALLEE_ADDRESS)) + .call_context_read(CALL_ID, CallContextFieldTag.IsPersistent, is_persistent) + ) + + if is_persistent: + rw_dictionary.tx_log_write(TX_ID, log_id, TxLogFieldTag.Address, 0, FQ(CALLEE_ADDRESS)) + + # append topic rows + construct_topic_rws(rw_dictionary, log_id, stack_pointer + 2, topics, is_persistent, randomness) + + # copy the log data + src_data = dict([(mstart + i, byte) for (i, byte) in enumerate(data)]) + if is_persistent: + copy_circuit.copy( rw_dictionary, - program_counter, - stack_pointer, - memory_size, - gas_left, - code_hash, - log_id, - is_persistent, + CALL_ID, + CopyDataTypeTag.Memory, + TX_ID, + CopyDataTypeTag.TxLog, + mstart, + mstart + msize, + 0, + msize, + src_data, + log_id=log_id, ) - steps.append(new_step) - src_addr += MAX_COPY_BYTES - data_start_index += MAX_COPY_BYTES - bytes_left -= MAX_COPY_BYTES - return steps + return stack_pointer + 2 + len(topics) -@pytest.mark.parametrize("topics, mstart, msize, is_persistent", TESTING_DATA) -def test_logs(topics: list, mstart: U64, msize: U64, is_persistent: bool): +@pytest.mark.parametrize("topics, mstart, msize, is_persistent", SINGLE_LOG_TESTING_DATA) +def test_single_log(topics: list, mstart: U64, msize: U64, is_persistent: bool): randomness = rand_fq() - data = rand_bytes(msize) - topic_count = len(topics) - next_memory_size, memory_expansion_cost = memory_expansion(mstart, msize) - dynamic_gas = GAS_COST_LOG + GAS_COST_LOG * topic_count + 8 * msize + memory_expansion_cost - bytecode = bytecodes[topic_count] + # init bytecode + bytecode = Bytecode() + log_code(bytecode, len(topics)) + bytecode.stop() bytecode_hash = RLC(bytecode.hash(), randomness) - tx = Transaction(id=TX_ID, gas=dynamic_gas) + + rw_dictionary = RWDictionary(1) + copy_circuit = CopyCircuit() + + next_memory_size, memory_expansion_cost = memory_expansion(0, mstart + msize) + dynamic_gas = ( + GAS_COST_LOG + GAS_COST_LOG * len(topics) + GAS_COST_LOGDATA * msize + memory_expansion_cost + ) steps = [ StepState( execution_state=ExecutionState.LOG, @@ -175,43 +188,15 @@ def test_logs(topics: list, mstart: U64, msize: U64, is_persistent: bool): code_hash=bytecode_hash, program_counter=0, stack_pointer=1015, - memory_size=mstart, + memory_size=0, gas_left=dynamic_gas, log_id=0, ) ] - - rw_dictionary = ( - RWDictionary(1) - .stack_read(CALL_ID, 1015, RLC(mstart, randomness)) - .stack_read(CALL_ID, 1016, RLC(msize, randomness)) - .call_context_read(CALL_ID, CallContextFieldTag.TxId, TX_ID) - .call_context_read(CALL_ID, CallContextFieldTag.IsStatic, FQ(0)) - .call_context_read(CALL_ID, CallContextFieldTag.CalleeAddress, FQ(CALLEE_ADDRESS)) - .call_context_read(CALL_ID, CallContextFieldTag.IsPersistent, is_persistent) + sp = make_log( + rw_dictionary, copy_circuit, randomness, 1015, 1, topics, mstart, msize, is_persistent ) - if is_persistent: - rw_dictionary.tx_log_write(TX_ID, 1, TxLogFieldTag.Address, 0, FQ(CALLEE_ADDRESS)) - - # append topic rows - construct_topic_rws(rw_dictionary, 1017, topics, is_persistent, randomness) - new_steps = make_log_copy_steps( - data, - mstart, - mstart, - msize, - rw_dictionary=rw_dictionary, - program_counter=1, - memory_size=next_memory_size, - stack_pointer=1015 + (2 + topic_count), - gas_left=0, - code_hash=bytecode_hash, - log_id=1, - is_persistent=is_persistent, - ) - # append memory & log steps and rows - steps.extend(new_steps) steps.append( StepState( execution_state=ExecutionState.STOP, @@ -221,18 +206,22 @@ def test_logs(topics: list, mstart: U64, msize: U64, is_persistent: bool): is_create=False, code_hash=bytecode_hash, program_counter=1, - stack_pointer=1015 + (2 + topic_count), + stack_pointer=sp, memory_size=next_memory_size, gas_left=0, log_id=is_persistent, ) ) + + tx = Transaction(id=TX_ID, gas=dynamic_gas) tables = Tables( block_table=set(Block().table_assignments(randomness)), tx_table=set(tx.table_assignments(randomness)), bytecode_table=set(bytecode.table_assignments(randomness)), rw_table=set(rw_dictionary.rws), + copy_circuit=copy_circuit.rows, ) + verify_copy_table(copy_circuit, tables) verify_steps( randomness=randomness, tables=tables, @@ -240,19 +229,83 @@ def test_logs(topics: list, mstart: U64, msize: U64, is_persistent: bool): ) -# helper to construct topics rows of RW table -def construct_topic_rws( - rw_dictionary: RWDictionary, - sp: int, - topics: list, - is_persistent: bool, - randomness: int, -): - for i in range(len(topics)): - rw_dictionary.stack_read(CALL_ID, sp, RLC(topics[i], randomness, 32)) - if is_persistent: - rw_dictionary.tx_log_write( - TX_ID, 1, TxLogFieldTag.Topic, i, RLC(topics[i], randomness, 32) +@pytest.mark.parametrize("log_entries", MULTI_LOGS_TESTING_DATA) +def test_multi_logs(log_entries): + randomness = rand_fq() + # init bytecode + bytecode = Bytecode() + total_gas = 0 + for topics, _, msize, _ in log_entries: + log_code(bytecode, len(topics)) + total_gas += GAS_COST_LOG + GAS_COST_LOG * len(topics) + GAS_COST_LOGDATA * msize + bytecode.stop() + bytecode_hash = RLC(bytecode.hash(), randomness) + + tx = Transaction(id=TX_ID, gas=total_gas) + steps = [] + rw_dictionary = RWDictionary(1) + copy_circuit = CopyCircuit() + + stack_pointer = 1000 + log_id = 0 + gas_left = total_gas + for pc, (topics, mstart, msize, is_persistent) in enumerate(log_entries): + steps.append( + StepState( + execution_state=ExecutionState.LOG, + rw_counter=rw_dictionary.rw_counter, + call_id=CALL_ID, + is_root=False, + is_create=False, + code_hash=bytecode_hash, + program_counter=pc, + stack_pointer=stack_pointer, + memory_size=50, + gas_left=gas_left, + log_id=log_id, ) + ) + stack_pointer = make_log( + rw_dictionary, + copy_circuit, + randomness, + stack_pointer, + log_id + 1, + topics, + mstart, + msize, + is_persistent, + ) + log_id += is_persistent + gas_left -= GAS_COST_LOG + GAS_COST_LOG * len(topics) + GAS_COST_LOGDATA * msize - sp += 1 + steps.append( + StepState( + execution_state=ExecutionState.STOP, + rw_counter=rw_dictionary.rw_counter, + call_id=CALL_ID, + is_root=False, + is_create=False, + code_hash=bytecode_hash, + program_counter=len(log_entries), + stack_pointer=stack_pointer, + memory_size=50, + gas_left=0, + log_id=log_id, + ) + ) + + tables = Tables( + block_table=set(Block().table_assignments(randomness)), + tx_table=set(tx.table_assignments(randomness)), + bytecode_table=set(bytecode.table_assignments(randomness)), + rw_table=set(rw_dictionary.rws), + copy_circuit=copy_circuit.rows, + ) + + verify_copy_table(copy_circuit, tables) + verify_steps( + randomness=randomness, + tables=tables, + steps=steps, + )