diff --git a/specs/tables.md b/specs/tables.md index b03963fb8..3b51f4a01 100644 --- a/specs/tables.md +++ b/specs/tables.md @@ -44,14 +44,15 @@ Details: - **Address (key2)** is reserved for stack, memory, and account addresses. - **StorageKey (key4)** is reserved for RLC encoded values -- **value, valuePrev**: variable size, depending on Tag (key0) and FieldTag (key3) where appropriate. +- **value, intialValue**: variable size, depending on Tag (key0) and FieldTag (key3) where appropriate. +- **root**: RLC encoded MPT state root. - **(rw) counter**: 32 bits, starts at 1. - **txID**: 32 bits, starts at 1 (corresponds to `txIndex + 1`). - **address**: 160 bits - **callID**: 32 bits, starts at 1 (corresponds to `rw_counter` when the call begins). - **Stack -> stackPointer**: 10 bits - **Memory -> memoryAddress**: 32 bits -- **Memory -> value, valuePrev**: 1 byte +- **Memory -> value**: 1 byte - **storageKey**: field size, RLC encoded (Random Linear Combination). - **TxLog Address column**: Packs 2 values: - **TxLog -> logID**: 32 bits, starts at 1 (corresponds to `logIndex + 1`), it is unique per tx/receipt. @@ -63,65 +64,65 @@ Details: NOTE: `kN` means `keyN` -| 0 *Rwc* | 1 *IsWrite* | 2 *Tag* (k0) | 3 *Id* (k1) | 4 *Address* (k2) | 5 *FieldTag* (k3) | 6 *StorageKey* (k4) | 7 *Value0* | 8 *Value1* | 9 *Aux0* | -| -------- | ----------- | -------------------------- | -------- | -------- | -------------------------- | ----------- | --------- | ---------- | --------------- | -| | | *RwTableTag* | | | | | | | | -| $counter | true | TxAccessListAccount | $txID | $address | | | $value | $valuePrev | 0 | -| $counter | true | TxAccessListAccountStorage | $txID | $address | | $storageKey | $value | $valuePrev | 0 | -| $counter | $isWrite | TxRefund | $txID | | | | $value | $valuePrev | 0 | -| | | | | | | | | | | -| | | | | | *AccountFieldTag* | | | | | -| $counter | $isWrite | Account | | $address | Nonce | | $value | $valuePrev | $committedValue | -| $counter | $isWrite | Account | | $address | Balance | | $value | $valuePrev | $committedValue | -| $counter | $isWrite | Account | | $address | CodeHash | | $value | $valuePrev | $committedValue | -| $counter | true | AccountDestructed | | $address | | | $value | $valuePrev | 0 | -| | | | | | | | | | | -| | | *CallContext constant* | | | *CallContextFieldTag* (ro) | | | | | -| $counter | false | CallContext | $callID | | RwCounterEndOfReversion | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | CallerId | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | TxId | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | Depth | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | CallerAddress | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | CalleeAddress | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | CallDataOffset | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | CallDataLength | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | ReturnDataOffset | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | ReturnDataLength | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | Value | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | IsSuccess | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | IsPersistent | | $value | 0 | 0 | -| $counter | false | CallContext | $callID | | IsStatic | | $value | 0 | 0 | +| 0 *Rwc* | 1 *IsWrite* | 2 *Tag* (k0) | 3 *Id* (k1) | 4 *Address* (k2) | 5 *FieldTag* (k3) | 6 *StorageKey* (k4) | 7 *Value* | 8 *InitialValue* | 9 *Root* | +| -------- | ----------- | -------------------------- | -------- | -------- | -------------------------- | ----------- | --------- | ---------------- | -------- | +| | | *RwTableTag* | | | | | | | | +| $counter | true | TxAccessListAccount | $txID | $address | | | $value | 0 | $root | +| $counter | true | TxAccessListAccountStorage | $txID | $address | | $storageKey | $value | 0 | $root | +| $counter | $isWrite | TxRefund | $txID | | | | $value | 0 | $root | +| | | | | | | | | | | +| | | | | | *AccountFieldTag* | | | | | +| $counter | $isWrite | Account | | $address | Nonce | | $value | $committedValue | $root | +| $counter | $isWrite | Account | | $address | Balance | | $value | $committedValue | $root | +| $counter | $isWrite | Account | | $address | CodeHash | | $value | $committedValue | $root | +| $counter | true | AccountDestructed | | $address | | | $value | 0 | $root | +| | | | | | | | | | | +| | | *CallContext constant* | | | *CallContextFieldTag* (ro) | | | | | +| $counter | false | CallContext | $callID | | RwCounterEndOfReversion | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | CallerId | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | TxId | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | Depth | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | CallerAddress | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | CalleeAddress | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | CallDataOffset | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | CallDataLength | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | ReturnDataOffset | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | ReturnDataLength | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | Value | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | IsSuccess | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | IsPersistent | | $value | 0 | $root | +| $counter | false | CallContext | $callID | | IsStatic | | $value | 0 | $root | | | | | | | | | | | | | | | *CallContext last callee* | | | *CallContextFieldTag* (rw) | | | | | -| $counter | $isWrite | CallContext | $callID | | LastCalleeId | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataOffset | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataLength | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | LastCalleeId | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataOffset | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | LastCalleeReturnDataLength | | $value | 0 | $root | | | | | | | | | | | | | | | *CallContext state* | | | *CallContextFieldTag* (rw) | | | | | -| $counter | $isWrite | CallContext | $callID | | IsRoot | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | IsCreate | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | CodeHash | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | ProgramCounter | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | StackPointer | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | GasLeft | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | MemorySize | | $value | 0 | 0 | -| $counter | $isWrite | CallContext | $callID | | ReversibleWriteCounter | | $value | 0 | 0 | +| $counter | $isWrite | CallContext | $callID | | IsRoot | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | IsCreate | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | CodeHash | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | ProgramCounter | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | StackPointer | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | GasLeft | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | MemorySize | | $value | 0 | $root | +| $counter | $isWrite | CallContext | $callID | | ReversibleWriteCounter | | $value | 0 | $root | | | | | | | | | | | | -| $counter | $isWrite | Stack | $callID | $stackPointer | | | $value | 0 | 0 | -| $counter | $isWrite | Memory | $callID | $memoryAddress | | | $value | 0 | 0 | -| $counter | $isWrite | AccountStorage | $txID | $address | | $storageKey | $value | $valuePrev | $committedValue | +| $counter | $isWrite | Stack | $callID | $stackPointer | | | $value | 0 | $root | +| $counter | $isWrite | Memory | $callID | $memoryAddress | | | $value | 0 | $root | +| $counter | $isWrite | AccountStorage | $txID | $address | | $storageKey | $value | $committedValue | | | | | | | | | | | | | | | | | | *TxLogTag* | | | | | -| $counter | true | TxLog | $txID | $logID,0 | Address | 0 | $value | 0 | 0 | -| $counter | true | TxLog | $txID | $logID,$topicIndex | Topic | 0 | $value | 0 | 0 | -| $counter | true | TxLog | $txID | $logID,$byteIndex | Data | 0 | $value | 0 | 0 | -| $counter | true | TxLog | $txID | $logID,0 | TopicLength | 0 | $value | 0 | 0 | -| $counter | true | TxLog | $txID | $logID,0 | DataLength | 0 | $value | 0 | 0 | -| | | | | | | | | | | -| | | | | | *TxReceiptTag* | | | | | -| $counter | false | TxReceipt | $txID | 0 | PostStateOrStatus | 0 | $value | 0 | 0 | -| $counter | false | TxReceipt | $txID | 0 | CumulativeGasUsed | 0 | $value | 0 | 0 | -| $counter | false | TxReceipt | $txID | 0 | LogLength | 0 | $value | 0 | 0 | +| $counter | true | TxLog | $txID | $logID,0 | Address | 0 | $value | 0 | $root | +| $counter | true | TxLog | $txID | $logID,$topicIndex | Topic | 0 | $value | 0 | $root | +| $counter | true | TxLog | $txID | $logID,$byteIndex | Data | 0 | $value | 0 | $root | +| $counter | true | TxLog | $txID | $logID,0 | TopicLength | 0 | $value | 0 | $root | +| $counter | true | TxLog | $txID | $logID,0 | DataLength | 0 | $value | 0 | $root | +| | | | | | | | | | | +| | | | | | *TxReceiptTag* | | | | | +| $counter | false | TxReceipt | $txID | 0 | PostStateOrStatus | 0 | $value | 0 | $root | +| $counter | false | TxReceipt | $txID | 0 | CumulativeGasUsed | 0 | $value | 0 | $root | +| $counter | false | TxReceipt | $txID | 0 | LogLength | 0 | $value | 0 | $root | ## `bytecode_table` @@ -198,9 +199,21 @@ Provided by the MPT (Merkle Patricia Trie) circuit. The current MPT circuit design exposes one big table where different targets require different lookups as described below. From this table, the following columns contain values using the RLC encoding: - Address +- FieldTag - Key - ValuePrev -- ValueCur +- Value +- RootPrev +- Root + +The circuit can prove that updates to account nonces, balances, or storage slots are correct, or that an account's code hash is some particular value. Note that it is not possible to change the code hash for an account without deleting it and then recreating it. + +| Address | FieldTag | Key | ValuePrev | Value | RootPrev | Root | +| - | - | - | - | - | - | - | +| $addr | Nonce | 0 | $noncePrev | $nonceCur | $rootPrev | $root | +| $addr | Balance | 0 | $balancePrev | $balanceCur | $rootPrev | $root | +| $addr | CodeHash | 0 |$codeHash | $codeHash | $rootPrev | $root | +| $addr | Storage | $key | $valuePrev | $value | $rootPrev | $root | ## Keccak Table @@ -208,7 +221,7 @@ See [tx.py](src/zkevm_specs/tx.py) | IsEnabled | InputRLC | InputLen | Output | | --------- | ---------- | -------- | ----------- | -| bool | $input_rlc | $input_length | $output_rlc | +| bool | $input_rlc | $input_length | $output_rlc | Column names in circuit: - IsEnabled: `is_final` diff --git a/src/zkevm_specs/evm/table.py b/src/zkevm_specs/evm/table.py index 567e7713e..fbbe55974 100644 --- a/src/zkevm_specs/evm/table.py +++ b/src/zkevm_specs/evm/table.py @@ -157,17 +157,6 @@ class RW(IntEnum): Write = 1 -class MPTTableTag(IntEnum): - """ - Tag for MPTTable lookup - """ - - Nonce = 1 - Balance = 2 - CodeHash = 4 - Storage = 8 - - class RWTableTag(IntEnum): """ Tag for RWTable lookup, where the RWTable an advice-column table built by @@ -390,10 +379,11 @@ class RWTableRow(TableRow): @dataclass(frozen=True) class MPTTableRow(TableRow): - counter: Expression - target: Expression # MPTTableTag address: Expression - key: Expression + field_tag: Expression + storage_key: Expression + root: Expression + root_prev: Expression value: Expression value_prev: Expression diff --git a/src/zkevm_specs/state.py b/src/zkevm_specs/state.py index 892404fee..7735320a6 100644 --- a/src/zkevm_specs/state.py +++ b/src/zkevm_specs/state.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Tuple, List, Set, Union, cast +from typing import NamedTuple, Tuple, List, Set, Dict, Optional from enum import IntEnum from math import log, ceil @@ -11,7 +11,6 @@ TxLogFieldTag, TxReceiptFieldTag, MPTTableRow, - MPTTableTag, lookup, ) @@ -78,8 +77,10 @@ class Row(NamedTuple): FQ,FQ,FQ,FQ,FQ,FQ,FQ,FQ, FQ,FQ,FQ,FQ,FQ,FQ,FQ,FQ] value: FQ - auxs: Tuple[FQ] - mpt_counter: FQ + committed_value: FQ + + root: FQ + # fmt: on def tag(self): @@ -114,39 +115,24 @@ class Tables: def __init__(self, mpt_table: Set[MPTTableRow]): self.mpt_table = mpt_table - def mpt_account_lookup( - self, - counter: Expression, - target: Expression, - address: Expression, - value: Expression, - value_prev: Expression, - ) -> MPTTableRow: - query = { - "counter": counter, - "target": target, - "address": address, - "key": FQ(0), - "value": value, - "value_prev": value_prev, - } - return lookup(MPTTableRow, self.mpt_table, query) - - def mpt_storage_lookup( + def mpt_lookup( self, - counter: Expression, address: Expression, - key: Expression, + field_tag: Expression, + storage_key: Expression, value: Expression, value_prev: Expression, + root: Expression, + root_prev: Expression, ) -> MPTTableRow: query = { - "counter": counter, - "target": FQ(MPTTableTag.Storage), "address": address, - "key": key, + "field_tag": field_tag, + "storage_key": storage_key, "value": value, "value_prev": value_prev, + "root": root, + "root_prev": root_prev, } return lookup(MPTTableRow, self.mpt_table, query) @@ -184,9 +170,6 @@ def check_start(row: Row, row_prev: Row): # 1.0. rw_counter is 0 assert row.rw_counter == 0 - # 1. mpt_counter is 0 - assert row.mpt_counter == 0 - @is_circuit_code def check_memory(row: Row, row_prev: Row): @@ -210,6 +193,9 @@ def check_memory(row: Row, row_prev: Row): # 2.3. value is a byte assert_in_range(row.value, 0, 2**8 - 1) + # 2.4 state root does not change + assert row.root == row_prev.root + @is_circuit_code def check_stack(row: Row, row_prev: Row): @@ -239,33 +225,28 @@ def check_stack(row: Row, row_prev: Row): stack_ptr_diff = get_stack_ptr(row) - get_stack_ptr(row_prev) assert_in_range(stack_ptr_diff, 0, 1) + # 3.4 state root does not change + assert row.root == row_prev.root -@is_circuit_code -def check_storage(row: Row, row_prev: Row, tables: Tables): - get_addr = lambda row: row.address() - get_storage_key = lambda row: row.storage_key() - get_committed_value = lambda row: row.auxs[0] +@is_circuit_code +def check_storage(row: Row, row_prev: Row, row_next: Row, tables: Tables): # 4.0. Unused keys are 0 assert row.field_tag() == 0 - # 4.1. When keys don't change, committed_value must be kept equal - if all_keys_eq(row, row_prev): - assert get_committed_value(row) == get_committed_value(row_prev) - - # TODO: The current spec does an MPT lookup for every storage update. The - # next optimization consists on doing a single lookup merging all updates - # for a given key, using the first and last access values. - - # 4.2. MPT storage lookup with incremental counter - # - # When the keys are equal in the previous row, the value_prev must be the - # value in previous row. When the keys change, value_prev is loaded from - # committed_value, which holds the storage value before the tx began. - value_prev = row_prev.value if all_keys_eq(row, row_prev) else get_committed_value(row) - tables.mpt_storage_lookup( - row.mpt_counter, get_addr(row), get_storage_key(row), row.value, value_prev - ) + # 4.1. MPT lookup for last access to (address, storage_key) + if not all_keys_eq(row, row_next): + tables.mpt_lookup( + row.address(), + row.field_tag(), + row.storage_key(), + row.value, + row.committed_value, + row.root, + row_prev.root, + ) + else: + assert row.root == row_prev.root @is_circuit_code @@ -277,36 +258,34 @@ def check_call_context(row: Row, row_prev: Row): assert row.address() == 0 assert row.storage_key() == 0 + # 5.1 state root does not change + assert row.root == row_prev.root + # TODO: Missing constraints @is_circuit_code -def check_account(row: Row, row_prev: Row, tables: Tables): +def check_account(row: Row, row_prev: Row, row_next: Row, tables: Tables): get_addr = lambda row: row.address() get_field_tag = lambda row: row.field_tag() - get_committed_value = lambda row: row.auxs[0] # 6.0. Unused keys are 0 assert row.id() == 0 assert row.storage_key() == 0 - # 6.1. When keys don't change, committed_value must be kept equal - if all_keys_eq(row, row_prev): - assert get_committed_value(row) == get_committed_value(row_prev) - - # TODO: The current spec does an MPT lookup for every storage update. The - # next optimization consists on doing a single lookup merging all updates - # for a given key, using the first and last access values. - - # 6.2. MPT storage lookup with incremental counter - # - # When the keys are equal in the previous row, the value_prev must be the - # value in previous row. When the keys change, value_prev is loaded from - # committed_value, which holds the account value before the block began. - value_prev = row_prev.value if all_keys_eq(row, row_prev) else get_committed_value(row) - tables.mpt_account_lookup( - row.mpt_counter, get_field_tag(row), get_addr(row), row.value, value_prev - ) + # 6.2. MPT storage lookup for last access to (address, field_tag) + if not all_keys_eq(row, row_next): + tables.mpt_lookup( + get_addr(row), + get_field_tag(row), + row.storage_key(), + row.value, + row.committed_value, + row.root, + row_prev.root, + ) + else: + assert row.root == row_prev.root # NOTE: Value transition rules are constrained via the EVM circuit: for example, # Nonce only increases by 1 or decreases by 1 (on revert). @@ -321,6 +300,9 @@ def check_tx_refund(row: Row, row_prev: Row): assert row.field_tag() == 0 assert row.storage_key() == 0 + # 7.1 state root does not change + assert row.root == row_prev.root + # TODO: Missing constraints # - When keys change, value must be 0 @@ -334,6 +316,9 @@ def check_tx_access_list_account(row: Row, row_prev: Row): assert row.field_tag() == 0 assert row.storage_key() == 0 + # 9.1 state root does not change + assert row.root == row_prev.root + # TODO: Missing constraints # - When keys change, value must be 0 @@ -347,7 +332,10 @@ def check_tx_access_list_account_storage(row: Row, row_prev: Row): # 8.0. Unused keys are 0 assert row.field_tag() == 0 - # TODO: Missing constraints + # 8.1 State root cannot change + assert row.root == row_prev.root + + # TODO: state root does not change # - When keys change, value must be 0 @@ -363,6 +351,8 @@ def check_account_destructed(row: Row, row_prev: Row): # TODO: Missing constraints # - When keys change, value must be 0 + # TODO: add MPT lookup + @is_circuit_code def check_tx_log(row: Row, row_prev: Row): @@ -373,6 +363,9 @@ def check_tx_log(row: Row, row_prev: Row): # 12.0 is_write is always true assert row.is_write == 1 + # 12.1 state root does not change + assert row.root == row_prev.root + # removed field_tag-specific constraints as issue # https://github.com/privacy-scaling-explorations/zkevm-specs/issues/221 @@ -403,9 +396,12 @@ def check_tx_receipt(row: Row, row_prev: Row): assert_in_range(tx_id, 1, 2**11) + # 11.4 state root does not change + assert row.root == row_prev.root + @is_circuit_code -def check_state_row(row: Row, row_prev: Row, tables: Tables, randomness: FQ): +def check_state_row(row: Row, row_prev: Row, row_next: Row, tables: Tables, randomness: FQ): # # Constraints that affect all rows, no matter which Tag they use # @@ -486,15 +482,8 @@ def keys_rwc_to_limbs_in_order(row: Row) -> List[FQ]: if row.is_write == 0 and all_keys_eq(row, row_prev): assert row.value == row_prev.value - # 7. Increment mpt_counter - # - # When row is Storage or Account, increment the mpt_counter by - # one, otherwise maintain the same value - if row.tag() != Tag.Start: - if row.tag() == Tag.Storage or row.tag() == Tag.Account: - assert row.mpt_counter == row_prev.mpt_counter + 1 - else: - assert row.mpt_counter == row_prev.mpt_counter + if all_keys_eq(row, row_prev): + assert row.committed_value == row_prev.committed_value # 8. RWC !=0 except for Tag.Start if row.tag() != Tag.Start: @@ -510,11 +499,11 @@ def keys_rwc_to_limbs_in_order(row: Row) -> List[FQ]: elif row.tag() == Tag.Stack: check_stack(row, row_prev) elif row.tag() == Tag.Storage: - check_storage(row, row_prev, tables) + check_storage(row, row_prev, row_next, tables) elif row.tag() == Tag.CallContext: check_call_context(row, row_prev) elif row.tag() == Tag.Account: - check_account(row, row_prev, tables) + check_account(row, row_prev, row_next, tables) elif row.tag() == Tag.TxRefund: check_tx_refund(row, row_prev) elif row.tag() == Tag.TxAccessListAccountStorage: @@ -545,7 +534,7 @@ class Operation(NamedTuple): field_tag: U256 storage_key: U256 value: FQ - aux0: FQ + committed_value: FQ class StartOp(Operation): @@ -738,104 +727,105 @@ def __new__(self, rw_counter: int, rw: RW, tx_id: int, field_tag: TxReceiptField # fmt: on -class Assigner: - mpt_counter: FQ +def op2row( + op: Operation, + randomness: FQ, + root: FQ, +) -> Row: + rw_counter = FQ(op.rw_counter) + is_write = FQ(0) if op.rw == RW.Read else FQ(1) + tag = FQ(op.tag) + id = FQ(op.id) + address = FQ(op.address) + address_bytes = op.address.to_bytes(20, "little") + address_limbs = tuple( + [FQ(address_bytes[i] + 2**8 * address_bytes[i + 1]) for i in range(0, 20, 2)] + ) + field_tag = FQ(op.field_tag) + storage_key_rlc = RLC(op.storage_key, randomness) + storage_key = storage_key_rlc.expr() + storage_key_bytes = tuple([FQ(x) for x in storage_key_rlc.le_bytes]) + + keys = (tag, id, address, field_tag, storage_key) + + value = FQ(op.value) + committed_value = FQ(op.committed_value) + + return Row( + rw_counter, + is_write, + keys, + address_limbs, # type: ignore + storage_key_bytes, # type: ignore + value, + committed_value, + root, + ) - def __init__(self): - self.mpt_counter = FQ(0) - def op2row(self, op: Operation, randomness: FQ) -> Row: - rw_counter = FQ(op.rw_counter) - is_write = FQ(0) if op.rw == RW.Read else FQ(1) - tag = FQ(op.tag) - id = FQ(op.id) - address = FQ(op.address) - address_bytes = op.address.to_bytes(20, "little") - address_limbs = tuple( - [FQ(address_bytes[i] + 2**8 * address_bytes[i + 1]) for i in range(0, 20, 2)] - ) - field_tag = FQ(op.field_tag) - storage_key_rlc = RLC(op.storage_key, randomness) - storage_key = storage_key_rlc.expr() - storage_key_bytes = tuple([FQ(x) for x in storage_key_rlc.le_bytes]) - value = FQ(op.value) - aux0 = FQ(op.aux0) +# Generate the advice Rows from a list of Operations +def assign_state_circuit(ops: List[Operation], randomness: FQ) -> List[Row]: + mpt_updates = _mock_mpt_updates(ops, randomness) + + # MPT keys for each Storage and Account row, and None otherwise. + mpt_keys = [_mpt_key(op) for op in ops] + # MPT updates for each Storage and Account row, and None otherwise. + updates = [None if key is None else mpt_updates.get(key) for key in mpt_keys] + # root_prev for each Storage and Account row, and None otherwise. + roots = [None if update is None else update.root_prev.expr() for update in updates] + + # With real mpt updates, the final root would be obtained from the public + # input. For _mock_mpt_updates, it's just 3 + 5 * number of MPT updates. + final_root = FQ(3 + 5 * len(mpt_updates)) + roots.append(final_root) + + # Fill in the None roots with the first non-None value that comes after it. + root: FQ = final_root + for i in reversed(range(len(roots))): + maybe_root = roots[i] + if maybe_root is None: + roots[i] = root + else: + root = maybe_root - if tag == FQ(Tag.Storage) or tag == FQ(Tag.Account): - self.mpt_counter += 1 + rows = [] + for op, maybe_root in zip(ops, roots[1:]): + assert maybe_root is not None + rows.append(op2row(op, randomness, maybe_root)) + return rows - # fmt: off - return Row(rw_counter, is_write, - # keys - (tag, id, address, field_tag, storage_key), address_limbs, storage_key_bytes, # type: ignore - value, (aux0,), # values - self.mpt_counter) - # fmt: on +def mpt_table_from_ops(ops: List[Operation], randomness: FQ) -> Set[MPTTableRow]: + return set(_mock_mpt_updates(ops, randomness).values()) -# def rw_table_tag2tag(tag: RWTableTag) -> FQ: -# ret = None -# if tag == RWTableTag.Memory: -# ret = Tag.Memory -# elif tag == RWTableTag.Stack: -# ret = Tag.Stack -# elif tag == RWTableTag.Storage: -# ret = Tag.Storage -# elif tag == RWTableTag.CallContext: -# ret = Tag.CallContext -# elif tag == RWTableTag.Account: -# ret = Tag.Account -# elif tag == RWTableTag.TxRefund: -# ret = Tag.TxRefund -# elif tag == RWTableTag.TxAccessListAccount: -# ret = Tag.TxAccessListAccount -# elif tag == RWTableTag.TxAccessListAccountStorage: -# ret = Tag.TxAccessListAccountStorage -# elif tag == RWTableTag.AccountDestructed: -# ret = Tag.AccountDestructed -# else: -# raise ValueError("Unreacheable") -# -# return FQ(ret) -# Generate the advice Rows from a list of Operations -def assign_state_circuit(ops: List[Operation], randomness: FQ) -> List[Row]: - assigner = Assigner() - rows = [assigner.op2row(op, randomness) for op in ops] - return rows +def _mpt_key(op: Operation) -> Optional[Tuple[FQ, FQ, FQ]]: + if op.tag != Tag.Account and op.tag != Tag.Storage: + return None + return (FQ(op.address), FQ(op.field_tag), FQ(op.storage_key)) -def mpt_table_from_ops( - ops_or_rows: Union[List[Operation], List[Row]], randomness: FQ -) -> Set[MPTTableRow]: - if isinstance(ops_or_rows[0], Operation): - rows = assign_state_circuit(cast(List[Operation], ops_or_rows), randomness) - else: - rows = cast(List[Row], ops_or_rows) - - mpt_rows = [] - for (idx, row) in enumerate(rows): - value_prev = row.auxs[0] - if idx > 0: - row_prev = rows[idx - 1] - if all_keys_eq(row, row_prev): - value_prev = row_prev.value - - if row.keys[0] == FQ(Tag.Storage): - mpt_rows.append( - MPTTableRow( - row.mpt_counter, - FQ(MPTTableTag.Storage), - row.keys[2], - row.keys[4], - row.value, - value_prev, - ) - ) - elif row.keys[0] == FQ(Tag.Account): - mpt_rows.append( - MPTTableRow( - row.mpt_counter, row.keys[3], row.keys[2], row.keys[4], row.value, value_prev - ) - ) - return set(mpt_rows) +def _mock_mpt_updates(ops: List[Operation], randomness: FQ) -> Dict[Tuple[FQ, FQ, FQ], MPTTableRow]: + # makes fake mpt updates for a list of rows. the state root starts at 5 and + # is incremented by 3 for each Account or Storage MPT update. + mpt_map = {} + + root = 3 + for op in ops: + mpt_key = _mpt_key(op) + if mpt_key is None or mpt_key in mpt_map: + continue + + new_root = root + 5 + mpt_map[mpt_key] = MPTTableRow( + FQ(op.address), + FQ(op.field_tag), + RLC(op.storage_key, randomness).expr(), + FQ(new_root), + FQ(root), + op.value, + op.committed_value, + ) + root = new_root + + return mpt_map diff --git a/tests/test_state_circuit.py b/tests/test_state_circuit.py index 2b1ce6a68..bddf70042 100644 --- a/tests/test_state_circuit.py +++ b/tests/test_state_circuit.py @@ -25,8 +25,9 @@ def verify( ok = True for (idx, row) in enumerate(rows): row_prev = rows[(idx - 1) % len(rows)] + row_next = rows[(idx + 1) % len(rows)] try: - check_state_row(row, row_prev, tables, randomness) + check_state_row(row, row_prev, row_next, tables, randomness) except AssertionError as e: if success: traceback.print_exc() @@ -95,6 +96,22 @@ def test_state_ok(): verify(ops, tables, randomness) +def test_mpt_updates_ok(): + # fmt: off + ops = [ + StartOp(), + + StorageOp(rw_counter=7, rw=RW.Read, tx_id=1, addr=0x12345678, key=0x1516, value=rlc(789), committed_value=rlc(789)), + StorageOp(rw_counter=8, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x4959, value=rlc(38491), committed_value=rlc(98765)), + + AccountOp(rw_counter=12, rw=RW.Write, addr=0x12345678, field_tag=AccountFieldTag.Nonce, value=FQ(1), committed_value=FQ(0)), + AccountOp(rw_counter=13, rw=RW.Read, addr=0x12345678, field_tag=AccountFieldTag.Balance, value=FQ(3), committed_value=FQ(0)), + ] + # fmt: on + tables = Tables(mpt_table_from_ops(ops, randomness)) + verify(ops, tables, randomness) + + def test_state_bad_key2(): # fmt: off ops = [ @@ -360,30 +377,3 @@ def test_storage_committed_value_bad(): # fmt: on tables = Tables(mpt_table_from_ops(ops, randomness)) verify(ops, tables, randomness, success=False) - - -def test_mpt_counter_bad(): - # fmt: off - ops = [ - StartOp(), - StorageOp(rw_counter=1, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(789), committed_value=rlc(789)), - StorageOp(rw_counter=2, rw=RW.Write, tx_id=1, addr=0x12345678, key=0x15161718, value=rlc(123), committed_value=rlc(789)), - ] - # fmt: on - rows = assign_state_circuit(ops, r) - # mpt_counter goes from 1 to 3 - rows[2] = rows[2]._replace(mpt_counter=FQ(3)) - tables = Tables(mpt_table_from_ops(ops, randomness)) - verify(rows, tables, randomness, success=False) - - # fmt: off - ops = [ - StartOp(), - StackOp(rw_counter=1, rw=RW.Write, call_id=1, stack_ptr=1021, value=rlc(4321)), - ] - # fmt: on - rows = assign_state_circuit(ops, r) - # mpt_counter increases when tag is not Account or Storage - rows[1] = rows[1]._replace(mpt_counter=FQ(1)) - tables = Tables(mpt_table_from_ops(ops, randomness)) - verify(rows, tables, randomness, success=False)