Skip to content
This repository was archived by the owner on Jul 5, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/zkevm_specs/bytecode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Sequence, Union, Tuple, Set
from collections import namedtuple
from .util import keccak256, fp_add, fp_mul, RLCStore
from .util import keccak256, fp_add, fp_mul, RLC
from .evm.opcode import get_push_size
from .encoding import U8, U256, is_circuit_code

Expand Down Expand Up @@ -87,7 +87,7 @@ def check_bytecode_row(


# Populate the circuit matrix
def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_store: RLCStore):
def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], randomness: int):
# All rows are usable in this emulation
last_row_offset = 2 ** k - 1

Expand All @@ -103,7 +103,7 @@ def assign_bytecode_circuit(k: int, bytecodes: Sequence[UnrolledBytecode], rlc_s
push_data_left = byte_push_size if is_code else push_data_left - 1

# Add the byte to the accumulator
hash_rlc = fp_add(fp_mul(hash_rlc, rlc_store.randomness), row[2])
hash_rlc = fp_add(fp_mul(hash_rlc, randomness), row[2])

# Set the data for this row
rows.append(
Expand Down Expand Up @@ -162,10 +162,10 @@ def assign_push_table():


# Generate keccak table
def assign_keccak_table(bytecodes: Sequence[bytes], rlc_store: RLCStore):
def assign_keccak_table(bytecodes: Sequence[bytes], randomness: int):
keccak_table = []
for bytecode in bytecodes:
hash = rlc_store.to_rlc(keccak256(bytecode), 32)
rlc = rlc_store.to_rlc(list(reversed(bytecode)))
hash = RLC(bytes(reversed(keccak256(bytecode))), randomness)
rlc = RLC(bytes(reversed(bytecode)), randomness, len(bytecode))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Which is the purpose of applying the RLC to the keccak inputs? Hasn't it been done previously in assign_bytecode_circuit?

Also, why do we need to reversethe bytecodet?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is the purpose of applying the RLC to the keccak inputs? Hasn't it been done previously in assign_bytecode_circuit?

This is a mocking keccak_table, for lookup the hash by RLC of bytecode and length.

Also, why do we need to reversethe bytecodet?

Because in bytecode circuit, it accumulates the bytes in big-endian order, so we need to reverse it for RLC, which takes input as little-endian.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes complete sense!

keccak_table.append((rlc, len(bytecode), hash))
return keccak_table
21 changes: 19 additions & 2 deletions src/zkevm_specs/evm/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
from typing import Callable, Dict

from ..execution_state import ExecutionState

from .begin_tx import *
from .end_tx import *
from .end_block import *

# Opcode's successful cases
from .add import *
from .push import *
from .jump import *
from .jumpi import *
from .push import *
from .block_coinbase import *
from .caller import *

# Error cases

EXECUTION_STATE_IMPL: Dict[ExecutionState, Callable] = {
ExecutionState.BeginTx: begin_tx,
ExecutionState.EndTx: end_tx,
ExecutionState.EndBlock: end_block,
ExecutionState.ADD: add,
ExecutionState.CALLER: caller,
ExecutionState.COINBASE: coinbase,
ExecutionState.JUMP: jump,
ExecutionState.JUMPI: jumpi,
ExecutionState.PUSH: push,
}
2 changes: 1 addition & 1 deletion src/zkevm_specs/evm/execution/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def add(instruction: Instruction):
instruction.select(is_sub, a, c),
)

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(3),
program_counter=Transition.delta(1),
Expand Down
73 changes: 42 additions & 31 deletions src/zkevm_specs/evm/execution/begin_tx.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
from ...util import GAS_COST_TX, GAS_COST_CREATION_TX, EMPTY_CODE_HASH
from ..execution_state import ExecutionState
from ..instruction import Instruction, Transition
from ..table import CallContextFieldTag, TxContextFieldTag, RW, AccountFieldTag
from ..precompiled import PrecompiledAddress
from ..table import CallContextFieldTag, TxContextFieldTag, AccountFieldTag


def begin_tx(instruction: Instruction, is_first_step: bool = False):
instruction.constrain_equal(instruction.curr.call_id, instruction.curr.rw_counter)
def begin_tx(instruction: Instruction):
call_id = instruction.curr.rw_counter

tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId)
rw_counter_end_of_reversion = instruction.call_context_lookup(CallContextFieldTag.RWCounterEndOfReversion)
is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent)
tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=call_id)
rw_counter_end_of_reversion = instruction.call_context_lookup(
CallContextFieldTag.RwCounterEndOfReversion, call_id=call_id
)
is_persistent = instruction.call_context_lookup(CallContextFieldTag.IsPersistent, call_id=call_id)

if is_first_step:
if instruction.is_first_step:
instruction.constrain_equal(instruction.curr.rw_counter, 1)
instruction.constrain_equal(tx_id, 1)

tx_caller_address = instruction.tx_lookup(tx_id, TxContextFieldTag.CallerAddress)
tx_callee_address = instruction.tx_lookup(tx_id, TxContextFieldTag.CalleeAddress)
tx_is_create = instruction.tx_lookup(tx_id, TxContextFieldTag.IsCreate)
tx_value = instruction.tx_lookup(tx_id, TxContextFieldTag.Value)
tx_call_data_length = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataLength)
tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress)
tx_callee_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CalleeAddress)
tx_is_create = instruction.tx_context_lookup(tx_id, TxContextFieldTag.IsCreate)
tx_value = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Value)
tx_call_data_length = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataLength)

# Verify nonce
tx_nonce = instruction.tx_lookup(tx_id, TxContextFieldTag.Nonce)
tx_nonce = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Nonce)
nonce, nonce_prev = instruction.account_write(tx_caller_address, AccountFieldTag.Nonce)
instruction.constrain_equal(tx_nonce, nonce_prev)
instruction.constrain_equal(nonce, nonce_prev + 1)

# TODO: Implement EIP 1559 (currently it supports legacy transaction format)
# Calculate gas fee
tx_gas = instruction.tx_lookup(tx_id, TxContextFieldTag.Gas)
tx_gas_price = instruction.tx_lookup(tx_id, TxContextFieldTag.GasPrice)
tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas)
tx_gas_price = instruction.tx_gas_price(tx_id)
gas_fee, carry = instruction.mul_word_by_u64(tx_gas_price, tx_gas)
instruction.constrain_zero(carry)

# TODO: Handle gas cost of tx level access list (EIP 2930)
tx_call_data_gas_cost = instruction.tx_lookup(tx_id, TxContextFieldTag.CallDataGasCost)
gas_left = tx_gas - (53000 if tx_is_create else 21000) - tx_call_data_gas_cost
tx_call_data_gas_cost = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallDataGasCost)
gas_left = tx_gas - (GAS_COST_CREATION_TX if tx_is_create else GAS_COST_TX) - tx_call_data_gas_cost
instruction.constrain_gas_left_not_underflow(gas_left)

# Prepare access list of caller and callee
Expand All @@ -47,27 +51,27 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False):
tx_caller_address,
tx_callee_address,
tx_value,
gas_fee=gas_fee,
is_persistent=is_persistent,
rw_counter_end_of_reversion=rw_counter_end_of_reversion,
gas_fee,
is_persistent,
rw_counter_end_of_reversion,
)

if tx_is_create:
# TODO: Verify receiver address
# TODO: Set opcode_source to tx_id
# TODO: Verify created address
# TODO: Decide what code_source should be (tx_id or hash of creation code)
raise NotImplementedError
elif tx_callee_address in list(PrecompiledAddress):
# TODO: Handle precompile
raise NotImplementedError
else:
code_hash, _ = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash)
code_hash = instruction.account_read(tx_callee_address, AccountFieldTag.CodeHash)

# Setup next call's context
# Note that:
# - CallerCallId, ReturnDataOffset, ReturnDataLength, Result
# should never be used in root call, so unnecessary to check
# - TxId is propagated from previous step or constraint to 1 if is_first_step
# - IsPersistent will be verified in the end of tx
# - CallerId, ReturnDataOffset, ReturnDataLength
# should never be used in root call, so unnecessary to be checked
# - TxId is checked from previous step or constraint to 1 if is_first_step
# - IsSuccess, IsPersistent will be verified in the end of tx
for (tag, value) in [
(CallContextFieldTag.Depth, 1),
(CallContextFieldTag.CallerAddress, tx_caller_address),
Expand All @@ -77,14 +81,21 @@ def begin_tx(instruction: Instruction, is_first_step: bool = False):
(CallContextFieldTag.Value, tx_value),
(CallContextFieldTag.IsStatic, False),
]:
instruction.constrain_equal(instruction.call_context_lookup(tag), value)
instruction.constrain_equal(instruction.call_context_lookup(tag, call_id=call_id), value)

instruction.constrain_new_context_state_transition(
instruction.step_state_transition_to_new_context(
rw_counter=Transition.delta(16),
call_id=Transition.persistent(),
call_id=Transition.to(call_id),
is_root=Transition.to(True),
is_create=Transition.to(False),
opcode_source=Transition.to(code_hash),
code_source=Transition.to(code_hash),
gas_left=Transition.to(gas_left),
state_write_counter=Transition.to(2),
)

# Constrain either:
# - is_empty_code and is_to_end_tx
# - (not is_empty_code) and (not is_to_end_tx)
is_empty_code = instruction.is_equal(code_hash, instruction.int_to_rlc(EMPTY_CODE_HASH, 32))
is_to_end_tx = instruction.is_equal(instruction.next.execution_state, ExecutionState.EndTx)
instruction.constrain_equal(is_empty_code + is_to_end_tx, 2 * is_empty_code * is_to_end_tx)
7 changes: 5 additions & 2 deletions src/zkevm_specs/evm/execution/block_coinbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ def coinbase(instruction: Instruction):
# check block table for coinbase address
instruction.constrain_equal(
address,
instruction.bytes_to_rlc(instruction.int_to_bytes(instruction.block_lookup(BlockContextFieldTag.Coinbase), 20)),
instruction.int_to_rlc(
instruction.block_context_lookup(BlockContextFieldTag.Coinbase),
20,
),
)

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(1),
program_counter=Transition.delta(1),
Expand Down
10 changes: 4 additions & 6 deletions src/zkevm_specs/evm/execution/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ def caller(instruction: Instruction):
# check [rw_table, call_context] table for caller address
instruction.constrain_equal(
address,
instruction.bytes_to_rlc(
instruction.int_to_bytes(
instruction.call_context_lookup(CallContextFieldTag.CallerAddress),
20,
)
instruction.int_to_rlc(
instruction.call_context_lookup(CallContextFieldTag.CallerAddress),
20,
),
)

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(2),
program_counter=Transition.delta(1),
Expand Down
30 changes: 30 additions & 0 deletions src/zkevm_specs/evm/execution/end_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from ..instruction import Instruction, Transition
from ..table import CallContextFieldTag


# TODO: Introduce constrain_instance to constrain the equality between witness
# and public input, for total_tx and total_rw


def end_block(instruction: Instruction):
if instruction.is_last_step:
# Verify final step has tx_id identical to the tx amount in tx_table.
total_tx = instruction.call_context_lookup(CallContextFieldTag.TxId)
instruction.constrain_equal(
total_tx,
max([tx_id for tx_id, *_ in instruction.tables.tx_table]),
)

# Verify rw_counter counts to identical rw amount in rw_table to ensure
# there is no malicious insertion.
total_rw = instruction.curr.rw_counter + 1 # extra 1 from the tx_id lookup
instruction.constrain_equal(
total_rw,
len(instruction.tables.rw_table),
)
else:
# Propagate rw_counter and call_id all the way down
instruction.constrain_step_state_transition(
rw_counter=Transition.same(),
call_id=Transition.same(),
)
47 changes: 47 additions & 0 deletions src/zkevm_specs/evm/execution/end_tx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from ...util import N_BYTES_GAS, MAX_REFUND_QUOTIENT_OF_GAS_USED
from ..execution_state import ExecutionState
from ..instruction import Instruction, Transition
from ..table import BlockContextFieldTag, CallContextFieldTag, TxContextFieldTag


def end_tx(instruction: Instruction):
tx_id = instruction.call_context_lookup(CallContextFieldTag.TxId)

# Handle gas refund (refund is capped to gas_used // MAX_REFUND_QUOTIENT_OF_GAS_USED in EIP 3529)
tx_gas = instruction.tx_context_lookup(tx_id, TxContextFieldTag.Gas)
gas_used = tx_gas - instruction.curr.gas_left
max_refund, _ = instruction.constant_divmod(gas_used, MAX_REFUND_QUOTIENT_OF_GAS_USED, N_BYTES_GAS)
refund = instruction.tx_refund_read(tx_id)
effective_refund = instruction.min(max_refund, refund, 8)

# Add effective_refund * gas_price back to caller's balance
tx_gas_price = instruction.tx_gas_price(tx_id)
value, carry = instruction.mul_word_by_u64(tx_gas_price, instruction.curr.gas_left + effective_refund)
instruction.constrain_zero(carry)
tx_caller_address = instruction.tx_context_lookup(tx_id, TxContextFieldTag.CallerAddress)
instruction.add_balance(tx_caller_address, [value])

# Add gas_used * effective_tip to coinbase's balance
base_fee = instruction.block_context_lookup(BlockContextFieldTag.BaseFee)
effective_tip, _ = instruction.sub_word(tx_gas_price, base_fee)
reward, carry = instruction.mul_word_by_u64(effective_tip, gas_used)
instruction.constrain_zero(carry)
coinbase = instruction.block_context_lookup(BlockContextFieldTag.Coinbase)
instruction.add_balance(coinbase, [reward])

# Go to next transaction
if instruction.next.execution_state == ExecutionState.BeginTx:
# Check next tx_id is increased by 1
instruction.constrain_equal(
instruction.call_context_lookup(CallContextFieldTag.TxId, call_id=instruction.next.rw_counter),
tx_id + 1,
)

# Do step state transition for rw_counter
instruction.constrain_step_state_transition(rw_counter=Transition.delta(5))
# Go to end of block
elif instruction.next.execution_state == ExecutionState.EndBlock:
# Do step state transition for rw_counter
instruction.constrain_step_state_transition(rw_counter=Transition.delta(4))
else:
raise ValueError("Unreacheable")
5 changes: 3 additions & 2 deletions src/zkevm_specs/evm/execution/jump.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ...util.param import N_BYTES_PROGRAM_COUNTER
from ..instruction import Instruction, Transition
from ..opcode import Opcode

Expand All @@ -11,13 +12,13 @@ def jump(instruction: Instruction):
dest = instruction.stack_pop()

# Get `dest` raw value in max 8 bytes
dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8))
dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER)

# Verify `dest` is code within byte code table
# assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True)
instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True))

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(1),
program_counter=Transition.to(dest_value),
Expand Down
5 changes: 3 additions & 2 deletions src/zkevm_specs/evm/execution/jumpi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ...util.param import N_BYTES_PROGRAM_COUNTER
from ..instruction import Instruction, Transition
from ..opcode import Opcode

Expand All @@ -16,12 +17,12 @@ def jumpi(instruction: Instruction):
pc_diff = 1
else:
# Get `dest` raw value in max 8 bytes
dest_value = instruction.bytes_to_int(instruction.rlc_to_bytes(dest, 8))
dest_value = instruction.rlc_to_int_exact(dest, N_BYTES_PROGRAM_COUNTER)
pc_diff = dest_value - instruction.curr.program_counter
# assert Opcode.JUMPDEST == instruction.opcode_lookup_at(dest_value, True)
instruction.constrain_equal(Opcode.JUMPDEST, instruction.opcode_lookup_at(dest_value, True))

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(2),
program_counter=Transition.delta(pc_diff),
Expand Down
8 changes: 4 additions & 4 deletions src/zkevm_specs/evm/execution/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ def push(instruction: Instruction):
num_additional_pushed = num_pushed - 1

value = instruction.stack_push()
value_bytes = instruction.rlc_to_bytes(value, 32)
value_le_bytes = instruction.rlc_to_le_bytes(value)
selectors = instruction.continuous_selectors(num_additional_pushed, 31)

for idx in range(32):
index = instruction.curr.program_counter + num_pushed - idx
if idx == 0 or selectors[idx - 1]:
instruction.constrain_equal(value_bytes[idx], instruction.opcode_lookup_at(index, False))
instruction.constrain_equal(value_le_bytes[idx], instruction.opcode_lookup_at(index, False))
else:
instruction.constrain_zero(value_bytes[idx])
instruction.constrain_zero(value_le_bytes[idx])

instruction.constrain_same_context_state_transition(
instruction.step_state_transition_in_same_context(
opcode,
rw_counter=Transition.delta(1),
program_counter=Transition.delta(1 + num_pushed),
Expand Down
2 changes: 2 additions & 0 deletions src/zkevm_specs/evm/execution_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class ExecutionState(IntEnum):
"""

BeginTx = auto()
EndTx = auto()
EndBlock = auto()

# Opcode's successful cases
STOP = auto()
Expand Down
Loading