diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py index e75f01f2279..1b4974702b3 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py @@ -1306,7 +1306,6 @@ def base_test_parametrizer_func( fixture_source_url: str, gas_benchmark_value: int, fixed_opcode_count: int | None, - witness_generator: Any, ) -> Any: """ Fixture used to instantiate an auto-fillable BaseTest object from @@ -1430,11 +1429,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: _info_metadata=t8n._info_metadata, ) - # Generate witness data if witness functionality is enabled via - # the witness plugin - if witness_generator is not None: - witness_generator(fixture) - fixture_path = fixture_collector.add_fixture( node_to_test_info(request.node), fixture, diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py deleted file mode 100644 index 00596675919..00000000000 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Pytest plugin for witness functionality. - -Provides --witness command-line option that checks for the witness-filler tool -in PATH and generates execution witness data for blockchain test fixtures when -enabled. -""" - -import shutil -import subprocess -from typing import Callable, List - -import pytest - -from execution_testing.base_types import EthereumTestRootModel -from execution_testing.fixtures.blockchain import ( - BlockchainFixture, - FixtureBlock, - WitnessChunk, -) -from execution_testing.forks import Paris - - -class WitnessFillerResult(EthereumTestRootModel[List[WitnessChunk]]): - """ - Model that defines the expected result from the `witness-filler` command. - """ - - root: List[WitnessChunk] - - -class Merge(Paris): - """ - Paris fork that serializes as 'Merge' for witness-filler compatibility. - - IMPORTANT: This class MUST be named 'Merge' (not 'MergeForWitness' or - similar) because the class name is used directly in Pydantic serialization, - and witness-filler expects exactly 'Merge' for this fork. - """ - - pass - - -def pytest_addoption(parser: pytest.Parser) -> None: - """Add witness command-line options to pytest.""" - witness_group = parser.getgroup( - "witness", "Arguments for witness functionality" - ) - witness_group.addoption( - "--witness", - "--witness-the-fitness", - action="store_true", - dest="witness", - default=False, - help=( - "Generate execution witness data for blockchain test fixtures using the " - "witness-filler tool (must be installed separately)." - ), - ) - - -def pytest_configure(config: pytest.Config) -> None: - """ - Pytest hook called after command line options have been parsed. - - If --witness is enabled, checks that the witness-filler tool is available - in PATH. - """ - if config.getoption("witness"): - # Check if witness-filler binary is available in PATH - if not shutil.which("witness-filler"): - pytest.exit( - "witness-filler tool not found in PATH. Please build and install witness-filler " - "from https://github.com/kevaundray/reth.git before using --witness flag.\n" - "Example: cargo install --git https://github.com/kevaundray/reth.git " - "witness-filler", - 1, - ) - - -@pytest.fixture -def witness_generator( - request: pytest.FixtureRequest, -) -> Callable[[BlockchainFixture], None] | None: - """ - Provide a witness generator function if --witness is enabled. - - Returns: None if witness functionality is disabled. Callable that generates - witness data for a BlockchainFixture if enabled. - """ - if not request.config.getoption("witness"): - return None - - def generate_witness(fixture: BlockchainFixture) -> None: - """ - Generate witness data for a blockchain fixture using the witness-filler - tool. - """ - if not isinstance(fixture, BlockchainFixture): - return None - - # Hotfix: witness-filler expects "Merge" but execution-spec-tests uses - # "Paris" - original_fork = None - if fixture.fork is Paris: - original_fork = fixture.fork - fixture.fork = Merge - - try: - result = subprocess.run( - ["witness-filler"], - input=fixture.model_dump_json(by_alias=True), - text=True, - capture_output=True, - ) - finally: - if original_fork is not None: - fixture.fork = original_fork - - if result.returncode != 0: - raise RuntimeError( - f"witness-filler tool failed with exit code {result.returncode}. " - f"stderr: {result.stderr}" - ) - - try: - result_model = WitnessFillerResult.model_validate_json( - result.stdout - ) - witnesses = result_model.root - - for i, witness in enumerate(witnesses): - if i < len(fixture.blocks): - block = fixture.blocks[i] - if isinstance(block, FixtureBlock): - block.execution_witness = witness - except Exception as e: - raise RuntimeError( - f"Failed to parse witness data from witness-filler tool. " - f"Output was: {result.stdout[:500]}{'...' if len(result.stdout) > 500 else ''}" - ) from e - - return generate_witness diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini b/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini index 37b44bf643f..74d65dd5e74 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini +++ b/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini @@ -10,7 +10,6 @@ addopts = -p execution_testing.cli.pytest_commands.plugins.forks.forks -p execution_testing.cli.pytest_commands.plugins.concurrency -p execution_testing.cli.pytest_commands.plugins.filler.pre_alloc - -p execution_testing.cli.pytest_commands.plugins.filler.witness -p execution_testing.cli.pytest_commands.plugins.filler.ported_tests -p execution_testing.cli.pytest_commands.plugins.filler.static_filler -p execution_testing.cli.pytest_commands.plugins.shared.benchmarking diff --git a/packages/testing/src/execution_testing/client_clis/cli_types.py b/packages/testing/src/execution_testing/client_clis/cli_types.py index 190192ffcca..bd7faf27523 100644 --- a/packages/testing/src/execution_testing/client_clis/cli_types.py +++ b/packages/testing/src/execution_testing/client_clis/cli_types.py @@ -25,6 +25,7 @@ TransactionException, UndefinedException, ) +from execution_testing.fixtures.blockchain import ExecutionWitness from execution_testing.logging import ( get_logger, ) @@ -289,6 +290,7 @@ class Result(CamelModel): ] = None traces: Traces | None = None opcode_count: OpcodeCount | None = None + execution_witness: ExecutionWitness | None = None TRaw = TypeVar("TRaw") diff --git a/packages/testing/src/execution_testing/fixtures/blockchain.py b/packages/testing/src/execution_testing/fixtures/blockchain.py index 786d6dd34ed..5b8ca852456 100644 --- a/packages/testing/src/execution_testing/fixtures/blockchain.py +++ b/packages/testing/src/execution_testing/fixtures/blockchain.py @@ -448,6 +448,14 @@ def from_fixture_header( ] +class ExecutionWitness(CamelModel): + """Execution witness containing RLP-encoded trie nodes and bytecodes accessed during block execution.""" + + nodes: List[str] + bytecodes: List[str] = [] + ancestors: List[str] = [] + + class FixtureEngineNewPayload(CamelModel): """ Representation of the `engine_newPayloadVX` information to be sent using @@ -468,6 +476,7 @@ class FixtureEngineNewPayload(CamelModel): ] | None ) = None + execution_witness: ExecutionWitness | None = None def valid(self) -> bool: """Return whether the payload is valid.""" @@ -581,24 +590,6 @@ def from_withdrawal(cls, w: WithdrawalGeneric) -> Self: return cls(**w.model_dump()) -class WitnessChunk(CamelModel): - """Represents execution witness data for a block.""" - - state: List[str] - codes: List[str] - keys: List[str] - headers: List[str] - - @classmethod - def parse_witness_chunks(cls, s: str) -> List[Self]: - """ - Parse multiple witness chunks from JSON string. - - Returns a list of WitnessChunk instances parsed from the JSON array. - """ - return [cls(**obj) for obj in json.loads(s)] - - class FixtureBlockBase(CamelModel): """ Representation of an Ethereum block within a test Fixture without RLP @@ -625,7 +616,7 @@ def strip_block_number_computed_field(cls, data: Any) -> Any: default_factory=list, alias="uncleHeaders" ) withdrawals: List[FixtureWithdrawal] | None = None - execution_witness: WitnessChunk | None = None + execution_witness: ExecutionWitness | None = None fork: Fork | None = Field(None, exclude=True) @computed_field(alias="blocknumber") # type: ignore[prop-decorator] diff --git a/packages/testing/src/execution_testing/specs/blockchain.py b/packages/testing/src/execution_testing/specs/blockchain.py index 06d6bebfc9a..7d4083117e5 100644 --- a/packages/testing/src/execution_testing/specs/blockchain.py +++ b/packages/testing/src/execution_testing/specs/blockchain.py @@ -96,6 +96,7 @@ def environment_from_parent_header(parent: "FixtureHeader") -> "Environment": parent_gas_limit=parent.gas_limit, parent_ommers_hash=parent.ommers_hash, block_hashes={parent.number: parent.block_hash}, + block_headers={parent.number: parent.rlp}, ) @@ -115,6 +116,9 @@ def apply_new_parent( block_hashes = env.block_hashes.copy() block_hashes[new_parent.number] = new_parent.block_hash updated["block_hashes"] = block_hashes + block_headers = env.block_headers.copy() + block_headers[new_parent.number] = new_parent.rlp + updated["block_headers"] = block_headers return env.copy(**updated) @@ -380,6 +384,7 @@ def get_fixture_block(self) -> FixtureBlock | InvalidFixtureBlock: if self.withdrawals is not None else None ), + execution_witness=self.result.execution_witness, fork=self.fork, ).with_rlp(txs=self.txs) @@ -414,6 +419,7 @@ def get_fixture_engine_new_payload(self) -> FixtureEngineNewPayload: else None, validation_error=self.expected_exception, error_code=self.engine_api_error_code, + execution_witness=self.result.execution_witness, ) def verify_transactions( diff --git a/packages/testing/src/execution_testing/test_types/block_types.py b/packages/testing/src/execution_testing/test_types/block_types.py index 2c73a558b98..a7ad164f653 100644 --- a/packages/testing/src/execution_testing/test_types/block_types.py +++ b/packages/testing/src/execution_testing/test_types/block_types.py @@ -137,6 +137,9 @@ def strip_computed_fields(cls, data: Any) -> Any: parent_beacon_block_root: Hash | None = Field(None) block_hashes: Dict[ZeroPaddedHexNumber, Hash] = Field(default_factory=dict) + block_headers: Dict[ZeroPaddedHexNumber, Bytes] = Field( + default_factory=dict + ) ommers: List[Hash] = Field(default_factory=list) withdrawals: List[Withdrawal] | None = Field(None) extra_data: Bytes = Field(Bytes(b"\x00"), exclude=True) diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index 1d8bbcc106b..7c81484c1f7 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -58,7 +58,10 @@ increment_nonce, modify_state, set_account_balance, + set_witness_metadata, state_root, + track_block_hash_access, + track_bytecode_access, ) from .transactions import ( AccessListTransaction, @@ -191,6 +194,36 @@ def get_last_256_block_hashes(chain: BlockChain) -> List[Hash32]: return recent_block_hashes +def get_last_256_block_headers(chain: BlockChain) -> List[Bytes]: + """ + Obtain the list of RLP-encoded headers of the previous 256 blocks. + + This function will return less headers for the first 256 blocks. + The headers are parallel to the hashes from get_last_256_block_hashes. + + Parameters + ---------- + chain : + History and current state. + + Returns + ------- + recent_block_headers : `List[Bytes]` + RLP-encoded headers of recent 256 blocks in order of increasing number. + + """ + recent_blocks = chain.blocks[-256:] + if len(recent_blocks) == 0: + return [] + + recent_block_headers: List[Bytes] = [] + for block in recent_blocks: + header_rlp = rlp.encode(block.header) + recent_block_headers.append(header_rlp) + + return recent_block_headers + + def state_transition(chain: BlockChain, block: Block) -> None: """ Attempts to apply a block to an existing block chain. @@ -235,6 +268,13 @@ def state_transition(chain: BlockChain, block: Block) -> None: parent_beacon_block_root=block.header.parent_beacon_block_root, ) + # Set witness metadata if tracking is enabled + set_witness_metadata( + block_env.state, + block.header.number, + get_last_256_block_headers(chain), + ) + block_output = apply_body( block_env=block_env, transactions=block.transactions, @@ -676,6 +716,7 @@ def process_checked_system_transaction( """ system_contract_code = get_account(block_env.state, target_address).code + track_bytecode_access(block_env.state, system_contract_code) if len(system_contract_code) == 0: raise InvalidBlock( @@ -724,6 +765,7 @@ def process_unchecked_system_transaction( """ system_contract_code = get_account(block_env.state, target_address).code + track_bytecode_access(block_env.state, system_contract_code) return process_system_transaction( block_env, target_address, @@ -770,6 +812,9 @@ def apply_body( data=block_env.parent_beacon_block_root, ) + # Track parent block access for witness (EIP-2935 system call) + track_block_hash_access(block_env.state, block_env.number - Uint(1)) + process_unchecked_system_transaction( block_env=block_env, target_address=HISTORY_STORAGE_ADDRESS, diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index 6571aa05c61..c9c2f3a3533 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -23,8 +23,63 @@ from ethereum_types.frozen import modify from ethereum_types.numeric import U256, Uint +from ethereum.crypto.hash import keccak256 + from .fork_types import EMPTY_ACCOUNT, Account, Address, Root -from .trie import EMPTY_TRIE_ROOT, Trie, copy_trie, root, trie_get, trie_set +from .trie import ( + EMPTY_TRIE_ROOT, + IncrementalMPT, + Trie, + Witness, + build_mpt, + copy_trie, + mpt_get, + mpt_root, + mpt_set, + root, + trie_get, + trie_set, +) + + +@dataclass +class WitnessState: + """ + Tracks state for deferred execution witness generation. + + Instead of recording witness nodes during execution, we preserve the + pre-block state and track which keys are accessed (reads) and modified + (writes). The witness is generated after all block execution by building + a fresh IncrementalMPT from the pre-block state, traversing read paths, + and applying the final diff for writes. + """ + + # Pre-block state (preserved at enable_witness_mode time) + pre_state_accounts: Dict[Address, Optional[Account]] + pre_state_storages: Dict[Address, Dict[Bytes32, U256]] + + # Dirty tracking during execution (writes) + dirty_accounts: Set[Address] = field(default_factory=set) + dirty_storage: Dict[Address, Set[Bytes32]] = field(default_factory=dict) + + # Access tracking during execution (reads) + accessed_accounts: Set[Address] = field(default_factory=set) + accessed_storage: Dict[Address, Set[Bytes32]] = field(default_factory=dict) + + # Bytecode tracking (code_hash -> bytecode) + accessed_bytecodes: Dict[Bytes32, Bytes] = field(default_factory=dict) + + # Ancestor tracking - oldest block accessed via BLOCKHASH + # All headers from this block to parent are needed for chain validation + oldest_accessed_block: Optional[Uint] = None + + # Block metadata for ancestor collection + current_block_number: Uint = field(default_factory=lambda: Uint(0)) + block_headers: List[Bytes] = field(default_factory=list) + + # Pre-execution MPTs + _main_mpt: Optional["IncrementalMPT"] = None + _storage_mpts: Optional[Dict[Address, "IncrementalMPT"]] = None @dataclass @@ -46,6 +101,7 @@ class State: ] ] = field(default_factory=list) created_accounts: Set[Address] = field(default_factory=set) + _witness_state: Optional[WitnessState] = None @dataclass @@ -70,6 +126,7 @@ def close_state(state: State) -> None: del state._storage_tries del state._snapshots del state.created_accounts + del state._witness_state def begin_transaction( @@ -189,8 +246,11 @@ def get_account_optional(state: State, address: Address) -> Optional[Account]: Account at address. """ - account = trie_get(state._main_trie, address) - return account + # Track accessed account for execution witness generation + if state._witness_state is not None: + state._witness_state.accessed_accounts.add(address) + + return trie_get(state._main_trie, address) def set_account( @@ -212,6 +272,10 @@ def set_account( """ trie_set(state._main_trie, address, account) + # Track dirty account for deferred witness generation + if state._witness_state is not None: + state._witness_state.dirty_accounts.add(address) + def destroy_account(state: State, address: Address) -> None: """ @@ -233,6 +297,82 @@ def destroy_account(state: State, address: Address) -> None: set_account(state, address, None) +def track_bytecode_access(state: State, code: Bytes) -> None: + """ + Track bytecode access for execution witness generation. + + Should be called when bytecode is accessed for execution purposes + (CALL variants, EXTCODESIZE, EXTCODECOPY, system contracts). + + Parameters + ---------- + state : State + The state with optional witness tracking. + code : Bytes + The bytecode being accessed. + + """ + # Skip if witness mode disabled or empty bytecode (EOAs) + if state._witness_state is None or len(code) == 0: + return + + # Compute hash and store for deduplication + code_hash = Bytes32(keccak256(code)) + if code_hash not in state._witness_state.accessed_bytecodes: + state._witness_state.accessed_bytecodes[code_hash] = code + + +def track_block_hash_access(state: State, block_number: Uint) -> None: + """ + Track a block hash access for execution witness generation. + + Called when BLOCKHASH opcode or system contracts access a block hash. + Tracks the oldest block accessed since all headers from that block + to the parent are needed for chain validation. + + Parameters + ---------- + state : State + The state with optional witness tracking. + block_number : Uint + The block number being accessed. + + """ + if state._witness_state is None: + return + + ws = state._witness_state + is_oldest = ( + ws.oldest_accessed_block is None + or block_number < ws.oldest_accessed_block + ) + if is_oldest: + ws.oldest_accessed_block = block_number + + +def set_witness_metadata( + state: State, current_block_number: Uint, block_headers: List[Bytes] +) -> None: + """ + Set block metadata needed for ancestor collection in witness generation. + + Parameters + ---------- + state : State + The state with witness tracking enabled. + current_block_number : Uint + The current block number being executed. + block_headers : List[Bytes] + RLP-encoded headers of previous blocks (up to 256). + + """ + if state._witness_state is None: + return + + state._witness_state.current_block_number = current_block_number + state._witness_state.block_headers = block_headers + + def destroy_storage(state: State, address: Address) -> None: """ Completely remove the storage at `address`. @@ -245,6 +385,15 @@ def destroy_storage(state: State, address: Address) -> None: Address of account whose storage is to be deleted. """ + # Track all pre-block storage keys as dirty for witness generation + if state._witness_state is not None: + ws = state._witness_state + if address in ws.pre_state_storages: + # Mark all pre-block storage keys as dirty (they're now deleted) + ws.dirty_storage.setdefault(address, set()).update( + ws.pre_state_storages[address].keys() + ) + if address in state._storage_tries: del state._storage_tries[address] @@ -291,12 +440,15 @@ def get_storage(state: State, address: Address, key: Bytes32) -> U256: Value at the key. """ + # Track accessed storage for execution witness generation + if state._witness_state is not None: + ws = state._witness_state + ws.accessed_storage.setdefault(address, set()).add(key) + trie = state._storage_tries.get(address) if trie is None: return U256(0) - value = trie_get(trie, key) - assert isinstance(value, U256) return value @@ -330,6 +482,10 @@ def set_storage( if trie._data == {}: del state._storage_tries[address] + # Track dirty storage for witness generation + if state._witness_state is not None: + state._witness_state.dirty_storage.setdefault(address, set()).add(key) + def storage_root(state: State, address: Address) -> Root: """ @@ -375,7 +531,20 @@ def state_root(state: State) -> Root: def get_storage_root(address: Address) -> Root: return storage_root(state, address) - return root(state._main_trie, get_storage_root=get_storage_root) + # Calculate root using patricialize (existing implementation) + patricialize_root = root( + state._main_trie, get_storage_root=get_storage_root + ) + + # If witness mode is enabled, verify IncrementalMPT produces same root + if state._witness_state is not None: + inc_root = incremental_state_root(state) + assert patricialize_root == inc_root, ( + f"Root mismatch! patricialize={patricialize_root.hex()} " + f"incremental={inc_root.hex()}" + ) + + return patricialize_root def account_exists(state: State, address: Address) -> bool: @@ -665,3 +834,265 @@ def set_transient_storage( trie_set(trie, key, value) if trie._data == {}: del transient_storage._tries[address] + + +def enable_witness_mode(state: State) -> None: + """ + Enable witness tracking mode for the state. + + Preserves the current (pre-block) state and sets up dirty tracking. + The actual witness generation is deferred until after block execution. + + Parameters + ---------- + state : + The state to enable witness mode on. + + """ + assert not state._snapshots, "Cannot enable witness during transaction" + + state._witness_state = WitnessState( + pre_state_accounts=dict(state._main_trie._data), + pre_state_storages={ + addr: dict(trie._data) + for addr, trie in state._storage_tries.items() + }, + ) + + +def is_witness_mode_enabled(state: State) -> bool: + """ + Check if witness tracking mode is enabled. + + Parameters + ---------- + state : + The state to check. + + Returns + ------- + enabled : `bool` + True if witness mode is enabled. + + """ + return state._witness_state is not None + + +def _build_witness_mpts(state: State) -> None: + """ + Build and cache the IncrementalMPTs for witness generation. + + This builds the MPTs from pre-block state, applies all reads and writes, + which records witness nodes as a side effect. The MPTs are cached in + WitnessState for reuse by root computation and witness extraction. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + """ + assert state._witness_state is not None + ws = state._witness_state + + # Already built + if ws._main_mpt is not None: + return + + # Build pre-block storage MPTs + storage_mpts: Dict[Address, IncrementalMPT[Bytes32, U256]] = {} + for address, data in ws.pre_state_storages.items(): + storage_mpts[address] = build_mpt( + dict(data), secured=True, default=U256(0) + ) + + def get_pre_storage_root(address: Address) -> Root: + if address in storage_mpts: + return mpt_root(storage_mpts[address]) + return EMPTY_TRIE_ROOT + + main_mpt = build_mpt( + dict(ws.pre_state_accounts), + secured=True, + default=None, + get_storage_root=get_pre_storage_root, + ) + + # 1. Do read-only storages accesses + for address, accessed_keys in ws.accessed_storage.items(): + if address not in storage_mpts: + continue + + for key in accessed_keys: + mpt_get(storage_mpts[address], key) + + # 2. Apply dirty storage to storages (writes) + for address, dirty_keys in ws.dirty_storage.items(): + if address not in storage_mpts: + # New storage created during block + storage_mpts[address] = build_mpt( + {}, secured=True, default=U256(0) + ) + + storage_trie = state._storage_tries.get(address) + # We do two passes to ensure deletions are processed after + # inserts/updates to minimize the number of nodes touched + # in the MPT. + # First pass: inserts and updates + for key in dirty_keys: + value = trie_get(storage_trie, key) if storage_trie else U256(0) + if value != 0: + mpt_set(storage_mpts[address], key, value) + # Second pass: deletions + for key in dirty_keys: + value = trie_get(storage_trie, key) if storage_trie else U256(0) + if value == 0: + mpt_set(storage_mpts[address], key, value) + + # Accounts are "dirty" if: + # - Account fields changed (nonce/balance/code) - tracked in dirty_accounts + # - Storage changed (storage root changed) - tracked in dirty_storage + all_dirty_accounts = ws.dirty_accounts | set(ws.dirty_storage.keys()) + + # 3. Traverse accounts that were read + for address in ws.accessed_accounts: + mpt_get(main_mpt, address) + + # 4. Apply dirty accounts + for address in all_dirty_accounts: + # Get new account data from usual trie + account = trie_get(state._main_trie, address) + + # Get storage root for this account + if address in storage_mpts: + addr_storage_root = mpt_root(storage_mpts[address]) + # Verify invariant: MPT root must match state storage root + # (if storage was fully cleared, it won't be in + # state._storage_tries) + if address in state._storage_tries: + assert addr_storage_root == root(state._storage_tries[address]) + else: + assert addr_storage_root == EMPTY_TRIE_ROOT + else: + # Verify invariant: no storage in state either + assert address not in state._storage_tries + addr_storage_root = EMPTY_TRIE_ROOT + + def get_storage_root_fn( + _: Address, sr: Root = addr_storage_root + ) -> Root: + return sr + + mpt_set( + main_mpt, + address, + account, + get_storage_root=get_storage_root_fn, + ) + + # Cache the built MPTs + ws._main_mpt = main_mpt + ws._storage_mpts = storage_mpts + + +def incremental_state_root(state: State) -> Root: + """ + Compute state root using IncrementalMPT. + + Builds the MPTs if not already built, then returns the root. + The MPT nodes cache their hashes, so subsequent calls are fast. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + root : `Root` + The state root computed via IncrementalMPT. + + """ + assert state._witness_state is not None + _build_witness_mpts(state) + assert state._witness_state._main_mpt is not None + assert state._witness_state._storage_mpts is not None + return mpt_root(state._witness_state._main_mpt) + + +def generate_witness(state: State) -> Tuple[Root, Witness]: + """ + Generate execution witness. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + root : `Root` + The state root computed via IncrementalMPT. + witness : `Witness` + The execution witness containing accessed nodes. + + """ + assert state._witness_state is not None + ws = state._witness_state + + # Ensure MPTs are built + _build_witness_mpts(state) + assert ws._main_mpt is not None + assert ws._storage_mpts is not None + + main_mpt = ws._main_mpt + storage_mpts = ws._storage_mpts + + # Collect ancestors from oldest accessed block to parent (inclusive) + # All headers in this range needed for parent hash chain validation + ancestors: List[Bytes] = [] + assert ws.oldest_accessed_block is not None and ws.block_headers + # Include all headers from oldest accessed to parent (block_number - 1) + for block_num in range( + int(ws.oldest_accessed_block), int(ws.current_block_number) + ): + offset = int(ws.current_block_number) - block_num + if offset <= len(ws.block_headers): + header_rlp = ws.block_headers[-offset] + if header_rlp: + ancestors.append(header_rlp) + + # Collect witness from all MPTs + witness = Witness( + accessed_nodes=dict(main_mpt.witness.accessed_nodes), + accessed_keys=set(main_mpt.witness.accessed_keys), + bytecodes=sorted(ws.accessed_bytecodes.values()), + ancestors=ancestors, + ) + for mpt in storage_mpts.values(): + witness.accessed_nodes.update(mpt.witness.accessed_nodes) + witness.accessed_keys.update(mpt.witness.accessed_keys) + + return mpt_root(main_mpt), witness + + +def get_witness(state: State) -> Witness: + """ + Get the collected witness data from the state. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + witness : `Witness` + The witness data containing accessed nodes. + + """ + if state._witness_state is None: + return Witness() + + _, witness = generate_witness(state) + return witness diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index fea8e0ece48..ab7820fb7a5 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -23,8 +23,10 @@ MutableMapping, Optional, Sequence, + Set, Tuple, TypeVar, + Union, cast, ) @@ -135,6 +137,70 @@ class BranchNode: InternalNode = LeafNode | ExtensionNode | BranchNode +# Mutable node types for incremental MPT updates +@dataclass +class MutableLeafNode: + """Mutable leaf node in the Merkle Trie for in-place updates.""" + + rest_of_key: Bytes + value: Bytes + _hash: Optional[Bytes] = None # Cached hash, invalidated on change + _rlp: Optional[Bytes] = None # Cached RLP encoding + + +@dataclass +class MutableExtensionNode: + """Mutable extension node in the Merkle Trie for in-place updates.""" + + key_segment: Bytes + child: "MutableNode" + _hash: Optional[Bytes] = None + _rlp: Optional[Bytes] = None + + +@dataclass +class MutableBranchNode: + """Mutable branch node in the Merkle Trie for in-place updates.""" + + children: List[Optional["MutableNode"]] # 16 children slots + value: Bytes # Value if key terminates at this branch + _hash: Optional[Bytes] = None + _rlp: Optional[Bytes] = None + + +MutableNode = Union[ + MutableLeafNode, MutableExtensionNode, MutableBranchNode, None +] + + +@dataclass +class Witness: + """Tracks nodes accessed during trie operations for witness generation.""" + + accessed_nodes: Dict[Bytes, Bytes] = field( + default_factory=dict + ) # hash -> RLP encoding + accessed_keys: Set[Bytes] = field(default_factory=set) # Original keys + bytecodes: List[Bytes] = field(default_factory=list) # Accessed bytecodes + ancestors: List[Bytes] = field(default_factory=list) # RLP-encoded headers + + +@dataclass +class IncrementalMPT(Generic[K, V]): + """ + An MPT that supports incremental updates and witness tracking. + + This maintains an actual tree structure that can be updated in-place, + rather than rebuilding the entire tree on each root calculation. + """ + + secured: bool + default: V + root_node: MutableNode = None + witness: Witness = field(default_factory=Witness) + _data: Dict[K, V] = field(default_factory=dict) # For backward compat + + def encode_internal_node(node: Optional[InternalNode]) -> Extended: """ Encodes a Merkle Trie node into its RLP form. The RLP will then be @@ -354,18 +420,21 @@ def bytes_to_nibble_list(bytes_: Bytes) -> Bytes: return Bytes(nibble_list) -def _prepare_trie( - trie: Trie[K, V], +def _prepare_data( + data: Mapping[K, V], + secured: bool, get_storage_root: Optional[Callable[[Address], Root]] = None, ) -> Mapping[Bytes, Bytes]: """ - Prepares the trie for root calculation. Removes values that are empty, + Prepares data for trie root calculation. Removes values that are empty, hashes the keys (if `secured == True`) and encodes all the nodes. Parameters ---------- - trie : - The `Trie` to prepare. + data : + The key-value data to prepare. + secured : + Whether keys should be hashed. get_storage_root : Function to get the storage root of an account. Needed to encode `Account` objects. @@ -378,7 +447,7 @@ def _prepare_trie( """ mapped: MutableMapping[Bytes, Bytes] = {} - for preimage, value in trie._data.items(): + for preimage, value in data.items(): if isinstance(value, Account): assert get_storage_root is not None address = Address(preimage) @@ -388,7 +457,7 @@ def _prepare_trie( if encoded_value == b"": raise AssertionError key: Bytes - if trie.secured: + if secured: # "secure" tries hash keys once before construction key = keccak256(preimage) else: @@ -398,6 +467,31 @@ def _prepare_trie( return mapped +def _prepare_trie( + trie: Trie[K, V], + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> Mapping[Bytes, Bytes]: + """ + Prepares the trie for root calculation. Removes values that are empty, + hashes the keys (if `secured == True`) and encodes all the nodes. + + Parameters + ---------- + trie : + The `Trie` to prepare. + get_storage_root : + Function to get the storage root of an account. Needed to encode + `Account` objects. + + Returns + ------- + out : `Mapping[ethereum.base_types.Bytes, Node]` + Object with keys mapped to nibble-byte form. + + """ + return _prepare_data(trie._data, trie.secured, get_storage_root) + + def root( trie: Trie[K, V], get_storage_root: Optional[Callable[[Address], Root]] = None, @@ -506,3 +600,679 @@ def patricialize( cast(BranchSubnodes, assert_type(subnodes, Tuple[Extended, ...])), value, ) + + +def _build_mutable_tree( + obj: Mapping[Bytes, Bytes], level: Uint +) -> MutableNode: + """ + Build a mutable tree structure from a prepared key-value mapping. + + This is similar to `patricialize()` but creates mutable nodes for + in-place updates. + + Parameters + ---------- + obj : + Underlying trie key-value pairs, with keys in nibble-list format. + level : + Current trie level. + + Returns + ------- + node : `MutableNode` + Root node of the mutable tree. + + """ + if len(obj) == 0: + return None + + arbitrary_key = next(iter(obj)) + + # Leaf node case + if len(obj) == 1: + return MutableLeafNode( + rest_of_key=arbitrary_key[level:], + value=obj[arbitrary_key], + ) + + # Check for common prefix (extension node) + substring = arbitrary_key[level:] + prefix_length = len(substring) + for key in obj: + prefix_length = min( + prefix_length, common_prefix_length(substring, key[level:]) + ) + if prefix_length == 0: + break + + if prefix_length > 0: + prefix = arbitrary_key[int(level) : int(level) + prefix_length] + child = _build_mutable_tree(obj, level + Uint(prefix_length)) + return MutableExtensionNode(key_segment=prefix, child=child) + + # Branch node case + branches: List[MutableMapping[Bytes, Bytes]] = [{} for _ in range(16)] + value = b"" + + for key in obj: + if len(key) == level: + value = obj[key] + else: + branches[key[level]][key] = obj[key] + + children: List[Optional[MutableNode]] = [ + _build_mutable_tree(branches[k], level + Uint(1)) for k in range(16) + ] + + return MutableBranchNode(children=children, value=value) + + +def build_mpt( + data: Mapping[K, V], + secured: bool, + default: V, + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> IncrementalMPT[K, V]: + """ + Build an IncrementalMPT from key-value data. + + This is called with the pre-execution state to create a mutable + tree structure that can be updated in-place during execution. + + Parameters + ---------- + data : + The source key-value data to build from. + secured : + Whether to hash keys before insertion. + default : + Default value for missing keys. + get_storage_root : + Function to get the storage root of an account. + + Returns + ------- + mpt : `IncrementalMPT[K, V]` + An incremental MPT with the same data. + + """ + prepared = _prepare_data(data, secured, get_storage_root) + root_node = _build_mutable_tree(prepared, Uint(0)) + + return IncrementalMPT( + secured=secured, + default=default, + root_node=root_node, + _data=dict(data), + ) + + +def _invalidate_hash(node: MutableNode) -> None: + """Invalidate the cached hash of a node.""" + if node is not None: + node._hash = None + node._rlp = None + + +def _record_witness( + witness: Witness, node: MutableNode, key: Optional[Bytes] = None +) -> None: + """Record a node access in the witness.""" + if node is None: + return + + # Record the key if provided + if key is not None: + witness.accessed_keys.add(key) + + # Compute hash and RLP if not cached + node_hash, node_rlp = _compute_node_hash_and_rlp(node) + if node_hash is not None and node_hash not in witness.accessed_nodes: + witness.accessed_nodes[node_hash] = node_rlp + + +def _encode_mutable_node(node: MutableNode) -> Extended: + """ + Encode a mutable node to its RLP form (unencoded tuple or hash). + + Similar to encode_internal_node but for mutable nodes. + """ + if node is None: + return b"" + elif isinstance(node, MutableLeafNode): + return ( + nibble_list_to_compact(node.rest_of_key, True), + node.value, + ) + elif isinstance(node, MutableExtensionNode): + child_encoded = _encode_mutable_node_to_extended(node.child) + return ( + nibble_list_to_compact(node.key_segment, False), + child_encoded, + ) + elif isinstance(node, MutableBranchNode): + children_encoded = [ + _encode_mutable_node_to_extended(child) for child in node.children + ] + return children_encoded + [node.value] + else: + raise AssertionError(f"Invalid mutable node type {type(node)}!") + + +def _encode_mutable_node_to_extended(node: MutableNode) -> Extended: + """ + Encode a mutable node for embedding in parent. + + Returns the hash if RLP >= 32 bytes, otherwise returns unencoded form. + """ + if node is None: + return b"" + + unencoded = _encode_mutable_node(node) + encoded = rlp.encode(unencoded) + + if len(encoded) < 32: + return unencoded + else: + return keccak256(encoded) + + +def _compute_node_hash_and_rlp( + node: MutableNode, +) -> Tuple[Optional[Bytes], Bytes]: + """ + Compute the hash and RLP encoding of a node. + + Returns (hash, rlp) where hash may be None for small nodes. + """ + if node is None: + return None, b"" + + # Use cached values if available + if node._rlp is not None: + if node._hash is not None: + return node._hash, node._rlp + elif len(node._rlp) >= 32: + return keccak256(node._rlp), node._rlp + + unencoded = _encode_mutable_node(node) + encoded = rlp.encode(unencoded) + + # Cache the RLP + node._rlp = encoded + + if len(encoded) >= 32: + node._hash = keccak256(encoded) + return node._hash, encoded + else: + return None, encoded + + +def mpt_get(mpt: IncrementalMPT[K, V], key: K) -> V: + """ + Get a value from the incremental MPT. + + Traverses the tree and records accessed nodes in the witness for + execution witness generation. + + Parameters + ---------- + mpt : + The incremental MPT to get from. + key : + Key to lookup. + + Returns + ------- + value : `V` + Value at the key, or the default value if not found. + + """ + # Get from flat data for consistency + value = mpt._data.get(key, mpt.default) + + # Record the key access + if mpt.secured: + nibble_key = bytes_to_nibble_list(keccak256(key)) + else: + nibble_key = bytes_to_nibble_list(key) + + mpt.witness.accessed_keys.add(key) + + # Traverse tree and record witness nodes + _mpt_traverse_for_witness(mpt, mpt.root_node, nibble_key, Uint(0)) + + return value + + +def _mpt_traverse_for_witness( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + level: Uint, +) -> None: + """Traverse the tree recording nodes in the witness.""" + if node is None: + return + + _record_witness(mpt.witness, node) + + if isinstance(node, MutableLeafNode): + # Leaf node - end of path + pass + elif isinstance(node, MutableExtensionNode): + # Extension node - follow if key matches + segment_len = len(node.key_segment) + lvl = int(level) + if key[lvl : lvl + segment_len] == node.key_segment: + _mpt_traverse_for_witness( + mpt, node.child, key, Uint(lvl + segment_len) + ) + elif isinstance(node, MutableBranchNode): + # Branch node - follow appropriate child + lvl = int(level) + if lvl < len(key): + child_idx = key[lvl] + _mpt_traverse_for_witness( + mpt, node.children[child_idx], key, Uint(lvl + 1) + ) + + +def mpt_set( + mpt: IncrementalMPT[K, V], + key: K, + value: V, + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> None: + """ + Set a value in the incremental MPT. + + Updates the tree in-place and invalidates cached hashes along the path. + + Parameters + ---------- + mpt : + The incremental MPT to update. + key : + Key to set. + value : + Value to set at the key. + get_storage_root : + Function to get storage root (for Account values). + + """ + # Update flat data for backward compatibility + if value == mpt.default: + if key in mpt._data: + del mpt._data[key] + else: + mpt._data[key] = value + + # Prepare key and value + if mpt.secured: + nibble_key = bytes_to_nibble_list(keccak256(key)) + else: + nibble_key = bytes_to_nibble_list(key) + + # Encode the value + if value == mpt.default: + encoded_value = b"" + elif isinstance(value, Account): + assert get_storage_root is not None + address = Address(key) + encoded_value = encode_node(value, get_storage_root(address)) + else: + encoded_value = encode_node(value) + + # Update tree + if encoded_value == b"": + # Delete operation + mpt.root_node = _mpt_delete_node( + mpt, mpt.root_node, nibble_key, Uint(0) + ) + else: + # Insert/update operation + mpt.root_node = _mpt_insert_node( + mpt, mpt.root_node, nibble_key, encoded_value, Uint(0) + ) + + +def _mpt_insert_node( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """ + Insert or update a value in the mutable tree. + + Returns the new/updated node for this position. + """ + _record_witness(mpt.witness, node) + + if node is None: + # Empty slot - create new leaf + return MutableLeafNode(rest_of_key=key[level:], value=value) + + _invalidate_hash(node) + + if isinstance(node, MutableLeafNode): + return _insert_into_leaf(mpt, node, key, value, level) + elif isinstance(node, MutableExtensionNode): + return _insert_into_extension(mpt, node, key, value, level) + elif isinstance(node, MutableBranchNode): + return _insert_into_branch(mpt, node, key, value, level) + else: + raise AssertionError(f"Invalid node type {type(node)}") + + +def _insert_into_leaf( + _mpt: IncrementalMPT, + node: MutableLeafNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is a leaf.""" + existing_key = node.rest_of_key + remaining_key = key[level:] + + if existing_key == remaining_key: + # Same key - update value + node.value = value + return node + + # Keys differ - need to create branch + prefix_len = common_prefix_length(existing_key, remaining_key) + + # Create new branch or extension + branch + if prefix_len > 0: + # Common prefix - create extension then branch + branch = _create_branch_from_two_leaves( + existing_key[prefix_len:], + node.value, + remaining_key[prefix_len:], + value, + ) + return MutableExtensionNode( + key_segment=existing_key[:prefix_len], child=branch + ) + else: + # No common prefix - create branch directly + return _create_branch_from_two_leaves( + existing_key, node.value, remaining_key, value + ) + + +def _create_branch_from_two_leaves( + key1: Bytes, value1: Bytes, key2: Bytes, value2: Bytes +) -> MutableBranchNode: + """Create a branch node from two key-value pairs.""" + children: List[Optional[MutableNode]] = [None] * 16 + branch_value = b"" + + if len(key1) == 0: + branch_value = value1 + else: + idx1 = key1[0] + children[idx1] = MutableLeafNode(rest_of_key=key1[1:], value=value1) + + if len(key2) == 0: + branch_value = value2 + else: + idx2 = key2[0] + children[idx2] = MutableLeafNode(rest_of_key=key2[1:], value=value2) + + return MutableBranchNode(children=children, value=branch_value) + + +def _insert_into_extension( + mpt: IncrementalMPT, + node: MutableExtensionNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is an extension.""" + remaining_key = key[level:] + segment = node.key_segment + prefix_len = common_prefix_length(segment, remaining_key) + + if prefix_len == len(segment): + # Key follows extension completely - recurse into child + node.child = _mpt_insert_node( + mpt, node.child, key, value, level + Uint(prefix_len) + ) + return node + + # Extension needs to be split + if prefix_len > 0: + # Partial match - create new extension for common prefix + new_child = _split_extension(node, remaining_key, value, prefix_len) + return MutableExtensionNode( + key_segment=segment[:prefix_len], child=new_child + ) + else: + # No common prefix - create branch at this level + return _split_extension(node, remaining_key, value, 0) + + +def _split_extension( + node: MutableExtensionNode, + remaining_key: Bytes, + value: Bytes, + prefix_len: int, +) -> MutableNode: + """Split an extension node when keys diverge.""" + segment = node.key_segment + children: List[Optional[MutableNode]] = [None] * 16 + branch_value = b"" + + # Place existing extension's child + segment_after_prefix = segment[prefix_len:] + if len(segment_after_prefix) == 1: + # Single nibble left - place child directly in branch + idx = segment_after_prefix[0] + children[idx] = node.child + elif len(segment_after_prefix) > 1: + # Multiple nibbles - create new extension + idx = segment_after_prefix[0] + children[idx] = MutableExtensionNode( + key_segment=segment_after_prefix[1:], child=node.child + ) + + # Place new value + key_after_prefix = remaining_key[prefix_len:] + if len(key_after_prefix) == 0: + branch_value = value + else: + idx = key_after_prefix[0] + if children[idx] is None: + children[idx] = MutableLeafNode( + rest_of_key=key_after_prefix[1:], value=value + ) + else: + # Need to merge with existing child (shouldn't happen normally) + raise AssertionError("Unexpected collision during split") + + return MutableBranchNode(children=children, value=branch_value) + + +def _insert_into_branch( + mpt: IncrementalMPT, + node: MutableBranchNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is a branch.""" + remaining_key = key[level:] + + if len(remaining_key) == 0: + # Value terminates at this branch + node.value = value + return node + + # Recurse into appropriate child + child_idx = remaining_key[0] + node.children[child_idx] = _mpt_insert_node( + mpt, node.children[child_idx], key, value, level + Uint(1) + ) + return node + + +def _mpt_delete_node( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """ + Delete a key from the mutable tree. + + Returns the updated node (may be different type or None). + """ + _record_witness(mpt.witness, node) + + if node is None: + return None + + _invalidate_hash(node) + + if isinstance(node, MutableLeafNode): + if node.rest_of_key == key[level:]: + return None # Key found, delete + return node # Key not found, no change + elif isinstance(node, MutableExtensionNode): + return _delete_from_extension(mpt, node, key, level) + elif isinstance(node, MutableBranchNode): + return _delete_from_branch(mpt, node, key, level) + else: + raise AssertionError(f"Invalid node type {type(node)}") + + +def _delete_from_extension( + mpt: IncrementalMPT, + node: MutableExtensionNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """Handle deletion when current node is an extension.""" + segment = node.key_segment + remaining_key = key[level:] + prefix_len = common_prefix_length(segment, remaining_key) + + if prefix_len < len(segment): + return node # Key doesn't follow this extension + + # Recurse into child + new_child = _mpt_delete_node( + mpt, node.child, key, level + Uint(len(segment)) + ) + + if new_child is None: + return None + + # Collapse if child is now an extension + if isinstance(new_child, MutableExtensionNode): + return MutableExtensionNode( + key_segment=segment + new_child.key_segment, + child=new_child.child, + ) + elif isinstance(new_child, MutableLeafNode): + # Merge extension into leaf + return MutableLeafNode( + rest_of_key=segment + new_child.rest_of_key, + value=new_child.value, + ) + + node.child = new_child + return node + + +def _delete_from_branch( + mpt: IncrementalMPT, + node: MutableBranchNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """Handle deletion when current node is a branch.""" + remaining_key = key[level:] + + if len(remaining_key) == 0: + # Delete value at this branch + node.value = b"" + else: + # Delete from child + child_idx = remaining_key[0] + node.children[child_idx] = _mpt_delete_node( + mpt, node.children[child_idx], key, level + Uint(1) + ) + + # Check if branch can be collapsed + return _collapse_branch(mpt, node) + + +def _collapse_branch( + mpt: IncrementalMPT, node: MutableBranchNode +) -> MutableNode: + """Collapse a branch node if it has only one child and no value.""" + non_empty = [(i, c) for i, c in enumerate(node.children) if c is not None] + + if len(non_empty) == 0 and node.value == b"": + return None + + if len(non_empty) == 1 and node.value == b"": + idx, child = non_empty[0] + _record_witness(mpt.witness, child) # Record the surviving child + nibble = Bytes([idx]) + + if isinstance(child, MutableLeafNode): + return MutableLeafNode( + rest_of_key=nibble + child.rest_of_key, + value=child.value, + ) + elif isinstance(child, MutableExtensionNode): + return MutableExtensionNode( + key_segment=nibble + child.key_segment, + child=child.child, + ) + else: + # Child is a branch - create extension + return MutableExtensionNode(key_segment=nibble, child=child) + + if len(non_empty) == 0 and node.value != b"": + # Only value at this branch - convert to leaf + return MutableLeafNode(rest_of_key=b"", value=node.value) + + return node + + +def mpt_root(mpt: IncrementalMPT) -> Root: + """ + Compute the root hash of the incremental MPT. + + Uses cached hashes where available for efficiency. + + Parameters + ---------- + mpt : + The incremental MPT. + + Returns + ------- + root : `Root` + The MPT root hash. + + """ + if mpt.root_node is None: + return EMPTY_TRIE_ROOT + + root_encoded = _encode_mutable_node_to_extended(mpt.root_node) + + if isinstance(root_encoded, Bytes): + return Root(root_encoded) + else: + return keccak256(rlp.encode(root_encoded)) diff --git a/src/ethereum/forks/osaka/utils/message.py b/src/ethereum/forks/osaka/utils/message.py index ecdc3143234..efd79d8eef9 100644 --- a/src/ethereum/forks/osaka/utils/message.py +++ b/src/ethereum/forks/osaka/utils/message.py @@ -16,7 +16,7 @@ from ethereum_types.numeric import Uint from ..fork_types import Address -from ..state import get_account +from ..state import get_account, track_bytecode_access from ..transactions import Transaction from ..vm import BlockEnvironment, Message, TransactionEnvironment from ..vm.precompiled_contracts.mapping import PRE_COMPILED_CONTRACTS @@ -63,6 +63,7 @@ def prepare_message( current_target = tx.to msg_data = tx.data code = get_account(block_env.state, tx.to).code + track_bytecode_access(block_env.state, code) code_address = tx.to else: raise AssertionError("Target must be address or empty bytes") diff --git a/src/ethereum/forks/osaka/vm/eoa_delegation.py b/src/ethereum/forks/osaka/vm/eoa_delegation.py index e6dd0f12011..8ad754cbe06 100644 --- a/src/ethereum/forks/osaka/vm/eoa_delegation.py +++ b/src/ethereum/forks/osaka/vm/eoa_delegation.py @@ -13,7 +13,13 @@ from ethereum.exceptions import InvalidBlock, InvalidSignatureError from ..fork_types import Address, Authorization -from ..state import account_exists, get_account, increment_nonce, set_code +from ..state import ( + account_exists, + get_account, + increment_nonce, + set_code, + track_bytecode_access, +) from ..utils.hexadecimal import hex_to_address from ..vm.gas import GAS_COLD_ACCOUNT_ACCESS, GAS_WARM_ACCESS from . import Evm, Message @@ -136,6 +142,7 @@ def access_delegation( state = evm.message.block_env.state code = get_account(state, address).code + track_bytecode_access(state, code) if not is_valid_delegation(code): return False, address, code, Uint(0) @@ -146,6 +153,7 @@ def access_delegation( evm.accessed_addresses.add(address) access_gas_cost = GAS_COLD_ACCOUNT_ACCESS code = get_account(state, address).code + track_bytecode_access(state, code) return True, address, code, access_gas_cost diff --git a/src/ethereum/forks/osaka/vm/instructions/block.py b/src/ethereum/forks/osaka/vm/instructions/block.py index 43be9e58e23..78f35a0ff48 100644 --- a/src/ethereum/forks/osaka/vm/instructions/block.py +++ b/src/ethereum/forks/osaka/vm/instructions/block.py @@ -13,6 +13,7 @@ from ethereum_types.numeric import U256, Uint +from ...state import track_block_hash_access from .. import Evm from ..gas import GAS_BASE, GAS_BLOCK_HASH, charge_gas from ..stack import pop, push @@ -57,6 +58,8 @@ def block_hash(evm: Evm) -> None: current_block_hash = evm.message.block_env.block_hashes[ -(current_block_number - block_number) ] + # Track access for witness generation + track_block_hash_access(evm.message.block_env.state, block_number) push(evm.stack, U256.from_be_bytes(current_block_hash)) diff --git a/src/ethereum/forks/osaka/vm/instructions/environment.py b/src/ethereum/forks/osaka/vm/instructions/environment.py index 28c595ee514..ba861ed5e1b 100644 --- a/src/ethereum/forks/osaka/vm/instructions/environment.py +++ b/src/ethereum/forks/osaka/vm/instructions/environment.py @@ -18,7 +18,7 @@ from ethereum.utils.numeric import ceil32 from ...fork_types import EMPTY_ACCOUNT -from ...state import get_account +from ...state import get_account, track_bytecode_access from ...utils.address import to_address_masked from ...vm.memory import buffer_read, memory_write from .. import Evm @@ -351,6 +351,7 @@ def extcodesize(evm: Evm) -> None: # OPERATION code = get_account(evm.message.block_env.state, address).code + track_bytecode_access(evm.message.block_env.state, code) codesize = U256(len(code)) push(evm.stack, codesize) @@ -393,6 +394,7 @@ def extcodecopy(evm: Evm) -> None: # OPERATION evm.memory += b"\x00" * extend_memory.expand_by code = get_account(evm.message.block_env.state, address).code + track_bytecode_access(evm.message.block_env.state, code) value = buffer_read(code, code_start_index, size) memory_write(evm.memory, memory_start_index, value) diff --git a/src/ethereum/forks/osaka/vm/interpreter.py b/src/ethereum/forks/osaka/vm/interpreter.py index 2d5eb64aff9..43ac1521c45 100644 --- a/src/ethereum/forks/osaka/vm/interpreter.py +++ b/src/ethereum/forks/osaka/vm/interpreter.py @@ -43,6 +43,7 @@ move_ether, rollback_transaction, set_code, + track_bytecode_access, ) from ..vm import Message from ..vm.eoa_delegation import get_delegated_code_address, set_delegation @@ -131,6 +132,7 @@ def process_message_call(message: Message) -> MessageCallOutput: message.disable_precompiles = True message.accessed_addresses.add(delegated_address) message.code = get_account(block_env.state, delegated_address).code + track_bytecode_access(block_env.state, message.code) message.code_address = delegated_address evm = process_message(message) diff --git a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py index 9a14efa54ca..7a84c5a859f 100644 --- a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py +++ b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py @@ -303,6 +303,26 @@ def close_state(self) -> Any: """close_state function of the fork.""" return self._module("state").close_state + @property + def enable_witness_mode(self) -> Any: + """enable_witness_mode function of the fork (Osaka+ only).""" + return self._module("state").enable_witness_mode + + @property + def get_witness(self) -> Any: + """get_witness function of the fork (Osaka+ only).""" + return self._module("state").get_witness + + @property + def set_witness_metadata(self) -> Any: + """set_witness_metadata function of the fork (Osaka+ only).""" + return self._module("state").set_witness_metadata + + @property + def track_block_hash_access(self) -> Any: + """track_block_hash_access function of the fork (Osaka+ only).""" + return self._module("state").track_block_hash_access + @property def create_ether(self) -> Any: """create_ether function of the fork.""" diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index 4988ef25bcd..3cb55c58b38 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -350,6 +350,17 @@ def pay_block_rewards(self, block_reward: U256, block_env: Any) -> None: block_env.state, ommer.coinbase, ommer_miner_reward ) + def _enable_witness_mode(self, block_env: Any) -> None: + """Enable witness tracking mode if supported by the fork (Osaka+).""" + if hasattr(self.fork, "enable_witness_mode"): + self.fork.enable_witness_mode(block_env.state) + if hasattr(self.fork, "set_witness_metadata"): + self.fork.set_witness_metadata( + block_env.state, + self.env.block_number, + self.env.block_headers, + ) + def run_state_test(self) -> Any: """ Apply a single transaction on pre-state. No system operations @@ -357,6 +368,9 @@ def run_state_test(self) -> Any: """ block_env = self.block_environment() block_output = self.fork.BlockOutput() + + self._enable_witness_mode(block_env) + self.backup_state() if len(self.txs.transactions) > 0: tx = self.txs.transactions[0] @@ -377,6 +391,11 @@ def run_state_test(self) -> Any: def _run_blockchain_test(self, block_env: Any, block_output: Any) -> None: if self.fork.has_compute_requests_hash: + # Track parent block access for witness (EIP-2935 system call) + if hasattr(self.fork, "track_block_hash_access"): + self.fork.track_block_hash_access( + block_env.state, block_env.number - Uint(1) + ) self.fork.process_unchecked_system_transaction( block_env=block_env, target_address=self.fork.HISTORY_STORAGE_ADDRESS, @@ -448,6 +467,8 @@ def run_blockchain_test(self) -> None: block_env = self.block_environment() block_output = self.fork.BlockOutput() + self._enable_witness_mode(block_env) + try: self._run_blockchain_test(block_env, block_output) except InvalidBlock as e: diff --git a/src/ethereum_spec_tools/evm_tools/t8n/env.py b/src/ethereum_spec_tools/evm_tools/t8n/env.py index be719ba7af5..5a1adb8dc4f 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/env.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/env.py @@ -47,6 +47,7 @@ class Env: parent_gas_limit: Optional[Uint] parent_base_fee_per_gas: Optional[Uint] block_hashes: Optional[List[Any]] + block_headers: List[bytes] parent_ommers_hash: Optional[Hash32] ommers: Any parent_beacon_block_root: Optional[Hash32] @@ -277,19 +278,31 @@ def read_block_difficulty(self, data: Any, t8n: "T8N") -> None: def read_block_hashes(self, data: Any) -> None: """ - Read the block hashes. Returns a maximum of 256 block hashes. + Read block hashes and headers. Supports both blockHashes and + blockHeaders inputs. If blockHeaders is provided, hashes are + computed from the RLP-encoded headers. """ - # Read the block hashes block_hashes: List[Any] = [] + block_headers: List[bytes] = [] - # The hex key strings provided might not have standard formatting + # Check if blockHeaders is provided (preferred) + clean_block_headers: Dict[int, bytes] = {} clean_block_hashes: Dict[int, Hash32] = {} - if "blockHashes" in data: + + if "blockHeaders" in data: + # Read headers and compute hashes from them + for key, value in data["blockHeaders"].items(): + int_key = int(key, 16) + header_rlp = hex_to_bytes(value) + clean_block_headers[int_key] = header_rlp + clean_block_hashes[int_key] = Hash32(keccak256(header_rlp)) + elif "blockHashes" in data: + # Fall back to blockHashes if no headers provided for key, value in data["blockHashes"].items(): int_key = int(key, 16) clean_block_hashes[int_key] = Hash32(hex_to_bytes(value)) - # Store a maximum of 256 block hashes. + # Store a maximum of 256 block hashes/headers max_blockhash_count = min(Uint(256), self.block_number) for number in range( self.block_number - max_blockhash_count, self.block_number @@ -299,7 +312,13 @@ def read_block_hashes(self, data: Any) -> None: else: block_hashes.append(None) + if number in clean_block_headers.keys(): + block_headers.append(clean_block_headers[number]) + else: + block_headers.append(b"") + self.block_hashes = block_hashes + self.block_headers = block_headers def read_ommers(self, data: Any, t8n: "T8N") -> None: """ diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index 544838a5dbb..a4d8bd2c299 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -270,6 +270,7 @@ class Result: block_exception: Optional[str] = None block_access_list: Optional[Any] = None block_access_list_hash: Optional[Hash32] = None + execution_witness: Optional[Dict[str, List[str]]] = None def get_receipts_from_output( self, @@ -332,6 +333,23 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: block_output.block_access_list ) ) + # Extract execution witness for Osaka+ forks + if hasattr(t8n.fork, "get_witness"): + witness = t8n.fork.get_witness(block_env.state) + self.execution_witness = { + "nodes": sorted( + [ + "0x" + node_rlp.hex() + for node_rlp in witness.accessed_nodes.values() + ] + ), + "bytecodes": [ + "0x" + bytecode.hex() for bytecode in witness.bytecodes + ], + "ancestors": [ + "0x" + header_rlp.hex() for header_rlp in witness.ancestors + ], + } @staticmethod def _block_access_list_to_json(account_changes: Any) -> Any: @@ -476,5 +494,7 @@ def to_json(self) -> Any: data["blockAccessListHash"] = encode_to_hex( self.block_access_list_hash ) + if self.execution_witness is not None: + data["executionWitness"] = self.execution_witness return data diff --git a/whitelist.txt b/whitelist.txt index 2cf28a99287..ca3c50141de 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -1293,6 +1293,7 @@ webSocket wei wfile whitelist +ws wikipedia wordlist Words2