From a021e451758a5c1b5a48209cbe37fa410f3d2be9 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 13:37:41 -0300 Subject: [PATCH 1/8] add incremental mpt and witness tracking efficient witness tracking wire it to the fixture output --- .../pytest_commands/plugins/filler/filler.py | 6 - .../pytest_commands/plugins/filler/witness.py | 143 ---- .../pytest_ini_files/pytest-fill.ini | 1 - .../client_clis/cli_types.py | 2 + .../execution_testing/fixtures/blockchain.py | 27 +- .../src/execution_testing/specs/blockchain.py | 2 + src/ethereum/forks/osaka/fork.py | 10 + src/ethereum/forks/osaka/state.py | 301 ++++++- src/ethereum/forks/osaka/trie.py | 742 ++++++++++++++++++ .../evm_tools/loaders/fork_loader.py | 10 + .../evm_tools/t8n/__init__.py | 9 + .../evm_tools/t8n/t8n_types.py | 16 +- 12 files changed, 1093 insertions(+), 176 deletions(-) delete mode 100644 packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py index e75f01f2279..1b4974702b3 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py +++ b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/filler.py @@ -1306,7 +1306,6 @@ def base_test_parametrizer_func( fixture_source_url: str, gas_benchmark_value: int, fixed_opcode_count: int | None, - witness_generator: Any, ) -> Any: """ Fixture used to instantiate an auto-fillable BaseTest object from @@ -1430,11 +1429,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: _info_metadata=t8n._info_metadata, ) - # Generate witness data if witness functionality is enabled via - # the witness plugin - if witness_generator is not None: - witness_generator(fixture) - fixture_path = fixture_collector.add_fixture( node_to_test_info(request.node), fixture, diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py b/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py deleted file mode 100644 index 00596675919..00000000000 --- a/packages/testing/src/execution_testing/cli/pytest_commands/plugins/filler/witness.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Pytest plugin for witness functionality. - -Provides --witness command-line option that checks for the witness-filler tool -in PATH and generates execution witness data for blockchain test fixtures when -enabled. -""" - -import shutil -import subprocess -from typing import Callable, List - -import pytest - -from execution_testing.base_types import EthereumTestRootModel -from execution_testing.fixtures.blockchain import ( - BlockchainFixture, - FixtureBlock, - WitnessChunk, -) -from execution_testing.forks import Paris - - -class WitnessFillerResult(EthereumTestRootModel[List[WitnessChunk]]): - """ - Model that defines the expected result from the `witness-filler` command. - """ - - root: List[WitnessChunk] - - -class Merge(Paris): - """ - Paris fork that serializes as 'Merge' for witness-filler compatibility. - - IMPORTANT: This class MUST be named 'Merge' (not 'MergeForWitness' or - similar) because the class name is used directly in Pydantic serialization, - and witness-filler expects exactly 'Merge' for this fork. - """ - - pass - - -def pytest_addoption(parser: pytest.Parser) -> None: - """Add witness command-line options to pytest.""" - witness_group = parser.getgroup( - "witness", "Arguments for witness functionality" - ) - witness_group.addoption( - "--witness", - "--witness-the-fitness", - action="store_true", - dest="witness", - default=False, - help=( - "Generate execution witness data for blockchain test fixtures using the " - "witness-filler tool (must be installed separately)." - ), - ) - - -def pytest_configure(config: pytest.Config) -> None: - """ - Pytest hook called after command line options have been parsed. - - If --witness is enabled, checks that the witness-filler tool is available - in PATH. - """ - if config.getoption("witness"): - # Check if witness-filler binary is available in PATH - if not shutil.which("witness-filler"): - pytest.exit( - "witness-filler tool not found in PATH. Please build and install witness-filler " - "from https://github.com/kevaundray/reth.git before using --witness flag.\n" - "Example: cargo install --git https://github.com/kevaundray/reth.git " - "witness-filler", - 1, - ) - - -@pytest.fixture -def witness_generator( - request: pytest.FixtureRequest, -) -> Callable[[BlockchainFixture], None] | None: - """ - Provide a witness generator function if --witness is enabled. - - Returns: None if witness functionality is disabled. Callable that generates - witness data for a BlockchainFixture if enabled. - """ - if not request.config.getoption("witness"): - return None - - def generate_witness(fixture: BlockchainFixture) -> None: - """ - Generate witness data for a blockchain fixture using the witness-filler - tool. - """ - if not isinstance(fixture, BlockchainFixture): - return None - - # Hotfix: witness-filler expects "Merge" but execution-spec-tests uses - # "Paris" - original_fork = None - if fixture.fork is Paris: - original_fork = fixture.fork - fixture.fork = Merge - - try: - result = subprocess.run( - ["witness-filler"], - input=fixture.model_dump_json(by_alias=True), - text=True, - capture_output=True, - ) - finally: - if original_fork is not None: - fixture.fork = original_fork - - if result.returncode != 0: - raise RuntimeError( - f"witness-filler tool failed with exit code {result.returncode}. " - f"stderr: {result.stderr}" - ) - - try: - result_model = WitnessFillerResult.model_validate_json( - result.stdout - ) - witnesses = result_model.root - - for i, witness in enumerate(witnesses): - if i < len(fixture.blocks): - block = fixture.blocks[i] - if isinstance(block, FixtureBlock): - block.execution_witness = witness - except Exception as e: - raise RuntimeError( - f"Failed to parse witness data from witness-filler tool. " - f"Output was: {result.stdout[:500]}{'...' if len(result.stdout) > 500 else ''}" - ) from e - - return generate_witness diff --git a/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini b/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini index 37b44bf643f..74d65dd5e74 100644 --- a/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini +++ b/packages/testing/src/execution_testing/cli/pytest_commands/pytest_ini_files/pytest-fill.ini @@ -10,7 +10,6 @@ addopts = -p execution_testing.cli.pytest_commands.plugins.forks.forks -p execution_testing.cli.pytest_commands.plugins.concurrency -p execution_testing.cli.pytest_commands.plugins.filler.pre_alloc - -p execution_testing.cli.pytest_commands.plugins.filler.witness -p execution_testing.cli.pytest_commands.plugins.filler.ported_tests -p execution_testing.cli.pytest_commands.plugins.filler.static_filler -p execution_testing.cli.pytest_commands.plugins.shared.benchmarking diff --git a/packages/testing/src/execution_testing/client_clis/cli_types.py b/packages/testing/src/execution_testing/client_clis/cli_types.py index 190192ffcca..bd7faf27523 100644 --- a/packages/testing/src/execution_testing/client_clis/cli_types.py +++ b/packages/testing/src/execution_testing/client_clis/cli_types.py @@ -25,6 +25,7 @@ TransactionException, UndefinedException, ) +from execution_testing.fixtures.blockchain import ExecutionWitness from execution_testing.logging import ( get_logger, ) @@ -289,6 +290,7 @@ class Result(CamelModel): ] = None traces: Traces | None = None opcode_count: OpcodeCount | None = None + execution_witness: ExecutionWitness | None = None TRaw = TypeVar("TRaw") diff --git a/packages/testing/src/execution_testing/fixtures/blockchain.py b/packages/testing/src/execution_testing/fixtures/blockchain.py index 786d6dd34ed..ca92e7888ed 100644 --- a/packages/testing/src/execution_testing/fixtures/blockchain.py +++ b/packages/testing/src/execution_testing/fixtures/blockchain.py @@ -448,6 +448,12 @@ def from_fixture_header( ] +class ExecutionWitness(CamelModel): + """Execution witness containing RLP-encoded trie nodes accessed during block execution.""" + + nodes: List[str] + + class FixtureEngineNewPayload(CamelModel): """ Representation of the `engine_newPayloadVX` information to be sent using @@ -468,6 +474,7 @@ class FixtureEngineNewPayload(CamelModel): ] | None ) = None + execution_witness: ExecutionWitness | None = None def valid(self) -> bool: """Return whether the payload is valid.""" @@ -581,24 +588,6 @@ def from_withdrawal(cls, w: WithdrawalGeneric) -> Self: return cls(**w.model_dump()) -class WitnessChunk(CamelModel): - """Represents execution witness data for a block.""" - - state: List[str] - codes: List[str] - keys: List[str] - headers: List[str] - - @classmethod - def parse_witness_chunks(cls, s: str) -> List[Self]: - """ - Parse multiple witness chunks from JSON string. - - Returns a list of WitnessChunk instances parsed from the JSON array. - """ - return [cls(**obj) for obj in json.loads(s)] - - class FixtureBlockBase(CamelModel): """ Representation of an Ethereum block within a test Fixture without RLP @@ -625,7 +614,7 @@ def strip_block_number_computed_field(cls, data: Any) -> Any: default_factory=list, alias="uncleHeaders" ) withdrawals: List[FixtureWithdrawal] | None = None - execution_witness: WitnessChunk | None = None + execution_witness: ExecutionWitness | None = None fork: Fork | None = Field(None, exclude=True) @computed_field(alias="blocknumber") # type: ignore[prop-decorator] diff --git a/packages/testing/src/execution_testing/specs/blockchain.py b/packages/testing/src/execution_testing/specs/blockchain.py index 06d6bebfc9a..e876edc53e2 100644 --- a/packages/testing/src/execution_testing/specs/blockchain.py +++ b/packages/testing/src/execution_testing/specs/blockchain.py @@ -380,6 +380,7 @@ def get_fixture_block(self) -> FixtureBlock | InvalidFixtureBlock: if self.withdrawals is not None else None ), + execution_witness=self.result.execution_witness, fork=self.fork, ).with_rlp(txs=self.txs) @@ -414,6 +415,7 @@ def get_fixture_engine_new_payload(self) -> FixtureEngineNewPayload: else None, validation_error=self.expected_exception, error_code=self.engine_api_error_code, + execution_witness=self.result.execution_witness, ) def verify_transactions( diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index 1d8bbcc106b..ea26282c47f 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -11,6 +11,7 @@ Entry point for the Ethereum specification. """ +import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -54,8 +55,10 @@ State, TransientStorage, destroy_account, + enable_witness_mode, get_account, increment_nonce, + is_witness_mode_enabled, modify_state, set_account_balance, state_root, @@ -762,6 +765,13 @@ def apply_body( The block output for the current block. """ + # Auto-enable witness mode if WITNESS_MODE env var is set + # This allows validating IncrementalMPT against patricialize + if os.environ.get("WITNESS_MODE") and not is_witness_mode_enabled( + block_env.state + ): + enable_witness_mode(block_env.state) + block_output = vm.BlockOutput() process_unchecked_system_transaction( diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index 6571aa05c61..d2cefe60c07 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -24,7 +24,45 @@ from ethereum_types.numeric import U256, Uint from .fork_types import EMPTY_ACCOUNT, Account, Address, Root -from .trie import EMPTY_TRIE_ROOT, Trie, copy_trie, root, trie_get, trie_set +from .trie import ( + EMPTY_TRIE_ROOT, + IncrementalMPT, + Trie, + Witness, + build_mpt, + copy_trie, + mpt_get, + mpt_root, + mpt_set, + root, + trie_get, + trie_set, +) + + +@dataclass +class WitnessState: + """ + Tracks state for deferred execution witness generation. + + Instead of recording witness nodes during execution, we preserve the + pre-block state and track which keys are accessed (reads) and modified + (writes). The witness is generated after all block execution by building + a fresh IncrementalMPT from the pre-block state, traversing read paths, + and applying the final diff for writes. + """ + + # Pre-block state (preserved at enable_witness_mode time) + pre_block_main_trie_data: Dict[Address, Optional[Account]] + pre_block_storage_tries_data: Dict[Address, Dict[Bytes32, U256]] + + # Dirty tracking during execution (writes) + dirty_accounts: Set[Address] = field(default_factory=set) + dirty_storage: Dict[Address, Set[Bytes32]] = field(default_factory=dict) + + # Access tracking during execution (reads) + accessed_accounts: Set[Address] = field(default_factory=set) + accessed_storage: Dict[Address, Set[Bytes32]] = field(default_factory=dict) @dataclass @@ -46,6 +84,7 @@ class State: ] ] = field(default_factory=list) created_accounts: Set[Address] = field(default_factory=set) + _witness_state: Optional[WitnessState] = None @dataclass @@ -70,6 +109,7 @@ def close_state(state: State) -> None: del state._storage_tries del state._snapshots del state.created_accounts + del state._witness_state def begin_transaction( @@ -135,6 +175,11 @@ def rollback_transaction( transient_storage : TransientStorage The transient storage of the transaction. + Note: Dirty tracking for witness generation persists across rollbacks. + This is correct because the witness needs to capture all nodes that + were accessed during execution, regardless of whether transactions + succeeded or failed. + """ state._main_trie, state._storage_tries = state._snapshots.pop() if not state._snapshots: @@ -189,8 +234,11 @@ def get_account_optional(state: State, address: Address) -> Optional[Account]: Account at address. """ - account = trie_get(state._main_trie, address) - return account + # Track accessed account for execution witness generation + if state._witness_state is not None: + state._witness_state.accessed_accounts.add(address) + + return trie_get(state._main_trie, address) def set_account( @@ -212,6 +260,10 @@ def set_account( """ trie_set(state._main_trie, address, account) + # Track dirty account for deferred witness generation + if state._witness_state is not None: + state._witness_state.dirty_accounts.add(address) + def destroy_account(state: State, address: Address) -> None: """ @@ -245,6 +297,17 @@ def destroy_storage(state: State, address: Address) -> None: Address of account whose storage is to be deleted. """ + # Track all pre-block storage keys as dirty for witness generation + if state._witness_state is not None: + ws = state._witness_state + if address in ws.pre_block_storage_tries_data: + if address not in ws.dirty_storage: + ws.dirty_storage[address] = set() + # Mark all pre-block storage keys as dirty (they're now deleted) + ws.dirty_storage[address].update( + ws.pre_block_storage_tries_data[address].keys() + ) + if address in state._storage_tries: del state._storage_tries[address] @@ -291,12 +354,17 @@ def get_storage(state: State, address: Address, key: Bytes32) -> U256: Value at the key. """ + # Track accessed storage for execution witness generation + if state._witness_state is not None: + ws = state._witness_state + if address not in ws.accessed_storage: + ws.accessed_storage[address] = set() + ws.accessed_storage[address].add(key) + trie = state._storage_tries.get(address) if trie is None: return U256(0) - value = trie_get(trie, key) - assert isinstance(value, U256) return value @@ -330,6 +398,13 @@ def set_storage( if trie._data == {}: del state._storage_tries[address] + # Track dirty storage for deferred witness generation + if state._witness_state is not None: + ws = state._witness_state + if address not in ws.dirty_storage: + ws.dirty_storage[address] = set() + ws.dirty_storage[address].add(key) + def storage_root(state: State, address: Address) -> Root: """ @@ -375,7 +450,18 @@ def state_root(state: State) -> Root: def get_storage_root(address: Address) -> Root: return storage_root(state, address) - return root(state._main_trie, get_storage_root=get_storage_root) + # Calculate root using patricialize (existing implementation) + patricialize_root = root(state._main_trie, get_storage_root=get_storage_root) + + # If witness mode is enabled, verify IncrementalMPT produces same root + if state._witness_state is not None: + incremental_root, _ = generate_witness(state) + assert patricialize_root == incremental_root, ( + f"Root mismatch! patricialize={patricialize_root.hex()} " + f"incremental={incremental_root.hex()}" + ) + + return patricialize_root def account_exists(state: State, address: Address) -> bool: @@ -665,3 +751,206 @@ def set_transient_storage( trie_set(trie, key, value) if trie._data == {}: del transient_storage._tries[address] + + +# ============================================================================= +# Witness Generation Functions +# ============================================================================= + + +def enable_witness_mode(state: State) -> None: + """ + Enable witness tracking mode for the state. + + Preserves the current (pre-block) state and sets up dirty tracking. + The actual witness generation is deferred until after block execution. + + Parameters + ---------- + state : + The state to enable witness mode on. + + """ + assert not state._snapshots, "Cannot enable witness mode during transaction" + + state._witness_state = WitnessState( + pre_block_main_trie_data=dict(state._main_trie._data), + pre_block_storage_tries_data={ + addr: dict(trie._data) + for addr, trie in state._storage_tries.items() + }, + ) + + +def is_witness_mode_enabled(state: State) -> bool: + """ + Check if witness tracking mode is enabled. + + Parameters + ---------- + state : + The state to check. + + Returns + ------- + enabled : `bool` + True if witness mode is enabled. + + """ + return state._witness_state is not None + + +def generate_witness(state: State) -> Tuple[Root, Witness]: + """ + Build MPT from pre-block state, generate execution witness, return root. + + This is called after all block execution completes. It builds a fresh + IncrementalMPT from the pre-block state, traverses read paths to record + pre-state nodes, then applies the final diff for writes. This produces + an execution witness containing nodes needed for: + - Verifying pre-state values that were read + - Re-executing the block + - Computing the post-state root + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + root : `Root` + The state root computed via IncrementalMPT. + witness : `Witness` + The execution witness containing accessed nodes. + + """ + assert state._witness_state is not None + ws = state._witness_state + + # Build fresh MPT from pre-block state + pre_main_trie: Trie[Address, Optional[Account]] = Trie( + secured=True, default=None + ) + pre_main_trie._data = dict(ws.pre_block_main_trie_data) + + # Build pre-block storage MPTs + storage_mpts: Dict[Address, IncrementalMPT[Bytes32, U256]] = {} + for address, data in ws.pre_block_storage_tries_data.items(): + pre_storage_trie: Trie[Bytes32, U256] = Trie( + secured=True, default=U256(0) + ) + pre_storage_trie._data = dict(data) + storage_mpts[address] = build_mpt(pre_storage_trie) + + def get_pre_storage_root(address: Address) -> Root: + if address in ws.pre_block_storage_tries_data: + pre_trie: Trie[Bytes32, U256] = Trie(secured=True, default=U256(0)) + pre_trie._data = dict(ws.pre_block_storage_tries_data[address]) + return root(pre_trie) + return EMPTY_TRIE_ROOT + + main_mpt = build_mpt(pre_main_trie, get_pre_storage_root) + + # 1. Traverse read-only storage keys (accessed but not dirty) + # This records pre-state paths for values that were read + for address, accessed_keys in ws.accessed_storage.items(): + dirty_keys = ws.dirty_storage.get(address, set()) + read_only_keys = accessed_keys - dirty_keys + + if read_only_keys: + if address not in storage_mpts: + # Storage was accessed but didn't exist pre-block + empty_trie: Trie[Bytes32, U256] = Trie( + secured=True, default=U256(0) + ) + storage_mpts[address] = build_mpt(empty_trie) + + for key in read_only_keys: + mpt_get(storage_mpts[address], key) + + # 2. Apply dirty storage (writes) + for address, dirty_keys in ws.dirty_storage.items(): + if address not in storage_mpts: + # New storage created during block + empty_trie: Trie[Bytes32, U256] = Trie( + secured=True, default=U256(0) + ) + storage_mpts[address] = build_mpt(empty_trie) + + storage_trie = state._storage_tries.get(address) + for key in dirty_keys: + value = trie_get(storage_trie, key) if storage_trie else U256(0) + mpt_set(storage_mpts[address], key, value) + + # Accounts are "dirty" if: + # - Account fields changed (nonce/balance/code) - tracked in dirty_accounts + # - Storage changed (storage root changed) - tracked in dirty_storage + all_dirty_accounts = ws.dirty_accounts | set(ws.dirty_storage.keys()) + + # 3. Traverse read-only accounts (accessed but not dirty) + # This records pre-state paths for accounts that were read + read_only_accounts = ws.accessed_accounts - all_dirty_accounts + for address in read_only_accounts: + mpt_get(main_mpt, address) + + # 4. Apply dirty accounts (writes, with current storage roots) + for address in all_dirty_accounts: + account = trie_get(state._main_trie, address) + + # Get storage root for this account + if address in storage_mpts: + addr_storage_root = mpt_root(storage_mpts[address]) + elif address in state._storage_tries: + addr_storage_root = root(state._storage_tries[address]) + else: + addr_storage_root = EMPTY_TRIE_ROOT + + # Use a closure that captures the specific storage root for this address + def make_storage_root_getter( + sr: Root, + ) -> Callable[[Address], Root]: + return lambda _: sr + + mpt_set( + main_mpt, + address, + account, + get_storage_root=make_storage_root_getter(addr_storage_root), + ) + + # Collect witness from all MPTs + witness = Witness( + accessed_nodes=dict(main_mpt.witness.accessed_nodes), + accessed_keys=set(main_mpt.witness.accessed_keys), + ) + for mpt in storage_mpts.values(): + witness.accessed_nodes.update(mpt.witness.accessed_nodes) + witness.accessed_keys.update(mpt.witness.accessed_keys) + + return mpt_root(main_mpt), witness + + +def get_witness(state: State) -> Witness: + """ + Get the collected witness data from the state. + + This generates the witness by building an MPT from the pre-block state + and applying only the final diff. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + witness : `Witness` + The witness data containing accessed nodes. + + """ + if state._witness_state is None: + return Witness() + + _, witness = generate_witness(state) + return witness diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index fea8e0ece48..b2da1381818 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -23,8 +23,10 @@ MutableMapping, Optional, Sequence, + Set, Tuple, TypeVar, + Union, cast, ) @@ -135,6 +137,68 @@ class BranchNode: InternalNode = LeafNode | ExtensionNode | BranchNode +# Mutable node types for incremental MPT updates +@dataclass +class MutableLeafNode: + """Mutable leaf node in the Merkle Trie for in-place updates.""" + + rest_of_key: Bytes + value: Bytes + _hash: Optional[Bytes] = None # Cached hash, invalidated on change + _rlp: Optional[Bytes] = None # Cached RLP encoding + + +@dataclass +class MutableExtensionNode: + """Mutable extension node in the Merkle Trie for in-place updates.""" + + key_segment: Bytes + child: "MutableNode" + _hash: Optional[Bytes] = None + _rlp: Optional[Bytes] = None + + +@dataclass +class MutableBranchNode: + """Mutable branch node in the Merkle Trie for in-place updates.""" + + children: List[Optional["MutableNode"]] # 16 children slots + value: Bytes # Value if key terminates at this branch + _hash: Optional[Bytes] = None + _rlp: Optional[Bytes] = None + + +MutableNode = Union[ + MutableLeafNode, MutableExtensionNode, MutableBranchNode, None +] + + +@dataclass +class Witness: + """Tracks nodes accessed during trie operations for witness generation.""" + + accessed_nodes: Dict[Bytes, Bytes] = field( + default_factory=dict + ) # hash -> RLP encoding + accessed_keys: Set[Bytes] = field(default_factory=set) # Original keys + + +@dataclass +class IncrementalMPT(Generic[K, V]): + """ + An MPT that supports incremental updates and witness tracking. + + This maintains an actual tree structure that can be updated in-place, + rather than rebuilding the entire tree on each root calculation. + """ + + secured: bool + default: V + root_node: MutableNode = None + witness: Witness = field(default_factory=Witness) + _data: Dict[K, V] = field(default_factory=dict) # For backward compat + + def encode_internal_node(node: Optional[InternalNode]) -> Extended: """ Encodes a Merkle Trie node into its RLP form. The RLP will then be @@ -506,3 +570,681 @@ def patricialize( cast(BranchSubnodes, assert_type(subnodes, Tuple[Extended, ...])), value, ) + + +# ============================================================================= +# Incremental MPT Functions +# ============================================================================= + + +def _build_mutable_tree( + obj: Mapping[Bytes, Bytes], level: Uint +) -> MutableNode: + """ + Build a mutable tree structure from a prepared key-value mapping. + + This is similar to `patricialize()` but creates mutable nodes for + in-place updates. + + Parameters + ---------- + obj : + Underlying trie key-value pairs, with keys in nibble-list format. + level : + Current trie level. + + Returns + ------- + node : `MutableNode` + Root node of the mutable tree. + + """ + if len(obj) == 0: + return None + + arbitrary_key = next(iter(obj)) + + # Leaf node case + if len(obj) == 1: + return MutableLeafNode( + rest_of_key=arbitrary_key[level:], + value=obj[arbitrary_key], + ) + + # Check for common prefix (extension node) + substring = arbitrary_key[level:] + prefix_length = len(substring) + for key in obj: + prefix_length = min( + prefix_length, common_prefix_length(substring, key[level:]) + ) + if prefix_length == 0: + break + + if prefix_length > 0: + prefix = arbitrary_key[int(level) : int(level) + prefix_length] + child = _build_mutable_tree(obj, level + Uint(prefix_length)) + return MutableExtensionNode(key_segment=prefix, child=child) + + # Branch node case + branches: List[MutableMapping[Bytes, Bytes]] = [] + for _ in range(16): + branches.append({}) + value = b"" + + for key in obj: + if len(key) == level: + value = obj[key] + else: + branches[key[level]][key] = obj[key] + + children: List[Optional[MutableNode]] = [ + _build_mutable_tree(branches[k], level + Uint(1)) for k in range(16) + ] + + return MutableBranchNode(children=children, value=value) + + +def build_mpt( + trie: Trie[K, V], + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> IncrementalMPT[K, V]: + """ + Build an IncrementalMPT from an existing Trie. + + This is called with the pre-execution state to create a mutable + tree structure that can be updated in-place during execution. + + Parameters + ---------- + trie : + The source Trie to build from. + get_storage_root : + Function to get the storage root of an account. + + Returns + ------- + mpt : `IncrementalMPT[K, V]` + An incremental MPT with the same data. + + """ + prepared = _prepare_trie(trie, get_storage_root) + root_node = _build_mutable_tree(prepared, Uint(0)) + + return IncrementalMPT( + secured=trie.secured, + default=trie.default, + root_node=root_node, + _data=copy.copy(trie._data), + ) + + +def _invalidate_hash(node: MutableNode) -> None: + """Invalidate the cached hash of a node.""" + if node is not None: + node._hash = None + node._rlp = None + + +def _record_witness( + witness: Witness, node: MutableNode, key: Optional[Bytes] = None +) -> None: + """Record a node access in the witness.""" + if node is None: + return + + # Record the key if provided + if key is not None: + witness.accessed_keys.add(key) + + # Compute hash and RLP if not cached + node_hash, node_rlp = _compute_node_hash_and_rlp(node) + if node_hash is not None and node_hash not in witness.accessed_nodes: + witness.accessed_nodes[node_hash] = node_rlp + + +def _encode_mutable_node(node: MutableNode) -> Extended: + """ + Encode a mutable node to its RLP form (unencoded tuple or hash). + + Similar to encode_internal_node but for mutable nodes. + """ + if node is None: + return b"" + elif isinstance(node, MutableLeafNode): + return ( + nibble_list_to_compact(node.rest_of_key, True), + node.value, + ) + elif isinstance(node, MutableExtensionNode): + child_encoded = _encode_mutable_node_to_extended(node.child) + return ( + nibble_list_to_compact(node.key_segment, False), + child_encoded, + ) + elif isinstance(node, MutableBranchNode): + children_encoded = [ + _encode_mutable_node_to_extended(child) + for child in node.children + ] + return children_encoded + [node.value] + else: + raise AssertionError(f"Invalid mutable node type {type(node)}!") + + +def _encode_mutable_node_to_extended(node: MutableNode) -> Extended: + """ + Encode a mutable node for embedding in parent. + + Returns the hash if RLP >= 32 bytes, otherwise returns unencoded form. + """ + if node is None: + return b"" + + unencoded = _encode_mutable_node(node) + encoded = rlp.encode(unencoded) + + if len(encoded) < 32: + return unencoded + else: + return keccak256(encoded) + + +def _compute_node_hash_and_rlp( + node: MutableNode, +) -> Tuple[Optional[Bytes], Bytes]: + """ + Compute the hash and RLP encoding of a node. + + Returns (hash, rlp) where hash may be None for small nodes. + """ + if node is None: + return None, b"" + + # Use cached values if available + if node._rlp is not None: + if node._hash is not None: + return node._hash, node._rlp + elif len(node._rlp) >= 32: + return keccak256(node._rlp), node._rlp + + unencoded = _encode_mutable_node(node) + encoded = rlp.encode(unencoded) + + # Cache the RLP + node._rlp = encoded + + if len(encoded) >= 32: + node._hash = keccak256(encoded) + return node._hash, encoded + else: + return None, encoded + + +def mpt_get(mpt: IncrementalMPT[K, V], key: K) -> V: + """ + Get a value from the incremental MPT. + + Traverses the tree and records accessed nodes in the witness for + execution witness generation. + + Parameters + ---------- + mpt : + The incremental MPT to get from. + key : + Key to lookup. + + Returns + ------- + value : `V` + Value at the key, or the default value if not found. + + """ + # Get from flat data for consistency + value = mpt._data.get(key, mpt.default) + + # Record the key access + if mpt.secured: + nibble_key = bytes_to_nibble_list(keccak256(key)) + else: + nibble_key = bytes_to_nibble_list(key) + + mpt.witness.accessed_keys.add(key) + + # Traverse tree and record witness nodes + _mpt_traverse_for_witness(mpt, mpt.root_node, nibble_key, Uint(0)) + + return value + + +def _mpt_traverse_for_witness( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + level: Uint, +) -> None: + """Traverse the tree recording nodes in the witness.""" + if node is None: + return + + _record_witness(mpt.witness, node) + + if isinstance(node, MutableLeafNode): + # Leaf node - end of path + pass + elif isinstance(node, MutableExtensionNode): + # Extension node - follow if key matches + segment_len = len(node.key_segment) + lvl = int(level) + if key[lvl : lvl + segment_len] == node.key_segment: + _mpt_traverse_for_witness( + mpt, node.child, key, Uint(lvl + segment_len) + ) + elif isinstance(node, MutableBranchNode): + # Branch node - follow appropriate child + lvl = int(level) + if lvl < len(key): + child_idx = key[lvl] + _mpt_traverse_for_witness( + mpt, node.children[child_idx], key, Uint(lvl + 1) + ) + + +def mpt_set( + mpt: IncrementalMPT[K, V], + key: K, + value: V, + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> None: + """ + Set a value in the incremental MPT. + + Updates the tree in-place and invalidates cached hashes along the path. + + Parameters + ---------- + mpt : + The incremental MPT to update. + key : + Key to set. + value : + Value to set at the key. + get_storage_root : + Function to get storage root (for Account values). + + """ + # Update flat data for backward compatibility + if value == mpt.default: + if key in mpt._data: + del mpt._data[key] + else: + mpt._data[key] = value + + # Prepare key and value + if mpt.secured: + nibble_key = bytes_to_nibble_list(keccak256(key)) + else: + nibble_key = bytes_to_nibble_list(key) + + # Encode the value + if value == mpt.default: + encoded_value = b"" + elif isinstance(value, Account): + assert get_storage_root is not None + address = Address(key) + encoded_value = encode_node(value, get_storage_root(address)) + else: + encoded_value = encode_node(value) + + # Update tree + if encoded_value == b"": + # Delete operation + mpt.root_node = _mpt_delete_node( + mpt, mpt.root_node, nibble_key, Uint(0) + ) + else: + # Insert/update operation + mpt.root_node = _mpt_insert_node( + mpt, mpt.root_node, nibble_key, encoded_value, Uint(0) + ) + + +def _mpt_insert_node( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """ + Insert or update a value in the mutable tree. + + Returns the new/updated node for this position. + """ + _record_witness(mpt.witness, node) + + if node is None: + # Empty slot - create new leaf + return MutableLeafNode(rest_of_key=key[level:], value=value) + + _invalidate_hash(node) + + if isinstance(node, MutableLeafNode): + return _insert_into_leaf(mpt, node, key, value, level) + elif isinstance(node, MutableExtensionNode): + return _insert_into_extension(mpt, node, key, value, level) + elif isinstance(node, MutableBranchNode): + return _insert_into_branch(mpt, node, key, value, level) + else: + raise AssertionError(f"Invalid node type {type(node)}") + + +def _insert_into_leaf( + mpt: IncrementalMPT, + node: MutableLeafNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is a leaf.""" + existing_key = node.rest_of_key + remaining_key = key[level:] + + if existing_key == remaining_key: + # Same key - update value + node.value = value + return node + + # Keys differ - need to create branch + prefix_len = common_prefix_length(existing_key, remaining_key) + + # Create new branch or extension + branch + if prefix_len > 0: + # Common prefix - create extension then branch + branch = _create_branch_from_two_leaves( + existing_key[prefix_len:], + node.value, + remaining_key[prefix_len:], + value, + ) + return MutableExtensionNode( + key_segment=existing_key[:prefix_len], child=branch + ) + else: + # No common prefix - create branch directly + return _create_branch_from_two_leaves( + existing_key, node.value, remaining_key, value + ) + + +def _create_branch_from_two_leaves( + key1: Bytes, value1: Bytes, key2: Bytes, value2: Bytes +) -> MutableBranchNode: + """Create a branch node from two key-value pairs.""" + children: List[Optional[MutableNode]] = [None] * 16 + branch_value = b"" + + if len(key1) == 0: + branch_value = value1 + else: + idx1 = key1[0] + children[idx1] = MutableLeafNode(rest_of_key=key1[1:], value=value1) + + if len(key2) == 0: + branch_value = value2 + else: + idx2 = key2[0] + children[idx2] = MutableLeafNode(rest_of_key=key2[1:], value=value2) + + return MutableBranchNode(children=children, value=branch_value) + + +def _insert_into_extension( + mpt: IncrementalMPT, + node: MutableExtensionNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is an extension.""" + remaining_key = key[level:] + segment = node.key_segment + prefix_len = common_prefix_length(segment, remaining_key) + + if prefix_len == len(segment): + # Key follows extension completely - recurse into child + node.child = _mpt_insert_node( + mpt, node.child, key, value, level + Uint(prefix_len) + ) + return node + + # Extension needs to be split + if prefix_len > 0: + # Partial match - create new extension for common prefix + new_child = _split_extension( + node, remaining_key, value, prefix_len + ) + return MutableExtensionNode( + key_segment=segment[:prefix_len], child=new_child + ) + else: + # No common prefix - create branch at this level + return _split_extension(node, remaining_key, value, 0) + + +def _split_extension( + node: MutableExtensionNode, + remaining_key: Bytes, + value: Bytes, + prefix_len: int, +) -> MutableNode: + """Split an extension node when keys diverge.""" + segment = node.key_segment + children: List[Optional[MutableNode]] = [None] * 16 + branch_value = b"" + + # Place existing extension's child + segment_after_prefix = segment[prefix_len:] + if len(segment_after_prefix) == 1: + # Single nibble left - place child directly in branch + idx = segment_after_prefix[0] + children[idx] = node.child + elif len(segment_after_prefix) > 1: + # Multiple nibbles - create new extension + idx = segment_after_prefix[0] + children[idx] = MutableExtensionNode( + key_segment=segment_after_prefix[1:], child=node.child + ) + + # Place new value + key_after_prefix = remaining_key[prefix_len:] + if len(key_after_prefix) == 0: + branch_value = value + else: + idx = key_after_prefix[0] + if children[idx] is None: + children[idx] = MutableLeafNode( + rest_of_key=key_after_prefix[1:], value=value + ) + else: + # Need to merge with existing child (shouldn't happen normally) + raise AssertionError("Unexpected collision during split") + + return MutableBranchNode(children=children, value=branch_value) + + +def _insert_into_branch( + mpt: IncrementalMPT, + node: MutableBranchNode, + key: Bytes, + value: Bytes, + level: Uint, +) -> MutableNode: + """Handle insertion when current node is a branch.""" + remaining_key = key[level:] + + if len(remaining_key) == 0: + # Value terminates at this branch + node.value = value + return node + + # Recurse into appropriate child + child_idx = remaining_key[0] + node.children[child_idx] = _mpt_insert_node( + mpt, node.children[child_idx], key, value, level + Uint(1) + ) + return node + + +def _mpt_delete_node( + mpt: IncrementalMPT, + node: MutableNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """ + Delete a key from the mutable tree. + + Returns the updated node (may be different type or None). + """ + _record_witness(mpt.witness, node) + + if node is None: + return None + + _invalidate_hash(node) + + if isinstance(node, MutableLeafNode): + if node.rest_of_key == key[level:]: + return None # Key found, delete + return node # Key not found, no change + elif isinstance(node, MutableExtensionNode): + return _delete_from_extension(mpt, node, key, level) + elif isinstance(node, MutableBranchNode): + return _delete_from_branch(mpt, node, key, level) + else: + raise AssertionError(f"Invalid node type {type(node)}") + + +def _delete_from_extension( + mpt: IncrementalMPT, + node: MutableExtensionNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """Handle deletion when current node is an extension.""" + segment = node.key_segment + remaining_key = key[level:] + prefix_len = common_prefix_length(segment, remaining_key) + + if prefix_len < len(segment): + return node # Key doesn't follow this extension + + # Recurse into child + new_child = _mpt_delete_node( + mpt, node.child, key, level + Uint(len(segment)) + ) + + if new_child is None: + return None + + # Collapse if child is now an extension + if isinstance(new_child, MutableExtensionNode): + return MutableExtensionNode( + key_segment=segment + new_child.key_segment, + child=new_child.child, + ) + elif isinstance(new_child, MutableLeafNode): + # Merge extension into leaf + return MutableLeafNode( + rest_of_key=segment + new_child.rest_of_key, + value=new_child.value, + ) + + node.child = new_child + return node + + +def _delete_from_branch( + mpt: IncrementalMPT, + node: MutableBranchNode, + key: Bytes, + level: Uint, +) -> MutableNode: + """Handle deletion when current node is a branch.""" + remaining_key = key[level:] + + if len(remaining_key) == 0: + # Delete value at this branch + node.value = b"" + else: + # Delete from child + child_idx = remaining_key[0] + node.children[child_idx] = _mpt_delete_node( + mpt, node.children[child_idx], key, level + Uint(1) + ) + + # Check if branch can be collapsed + return _collapse_branch(mpt, node) + + +def _collapse_branch(mpt: IncrementalMPT, node: MutableBranchNode) -> MutableNode: + """Collapse a branch node if it has only one child and no value.""" + non_empty = [(i, c) for i, c in enumerate(node.children) if c is not None] + + if len(non_empty) == 0 and node.value == b"": + return None + + if len(non_empty) == 1 and node.value == b"": + idx, child = non_empty[0] + _record_witness(mpt.witness, child) # Record the surviving child + nibble = Bytes([idx]) + + if isinstance(child, MutableLeafNode): + return MutableLeafNode( + rest_of_key=nibble + child.rest_of_key, + value=child.value, + ) + elif isinstance(child, MutableExtensionNode): + return MutableExtensionNode( + key_segment=nibble + child.key_segment, + child=child.child, + ) + else: + # Child is a branch - create extension + return MutableExtensionNode(key_segment=nibble, child=child) + + if len(non_empty) == 0 and node.value != b"": + # Only value at this branch - convert to leaf + return MutableLeafNode(rest_of_key=b"", value=node.value) + + return node + + +def mpt_root(mpt: IncrementalMPT) -> Root: + """ + Compute the root hash of the incremental MPT. + + Uses cached hashes where available for efficiency. + + Parameters + ---------- + mpt : + The incremental MPT. + + Returns + ------- + root : `Root` + The MPT root hash. + + """ + if mpt.root_node is None: + return EMPTY_TRIE_ROOT + + root_encoded = _encode_mutable_node_to_extended(mpt.root_node) + + if isinstance(root_encoded, Bytes): + return Root(root_encoded) + else: + return keccak256(rlp.encode(root_encoded)) diff --git a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py index 9a14efa54ca..8809a46688e 100644 --- a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py +++ b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py @@ -303,6 +303,16 @@ def close_state(self) -> Any: """close_state function of the fork.""" return self._module("state").close_state + @property + def enable_witness_mode(self) -> Any: + """enable_witness_mode function of the fork (Osaka+ only).""" + return self._module("state").enable_witness_mode + + @property + def get_witness(self) -> Any: + """get_witness function of the fork (Osaka+ only).""" + return self._module("state").get_witness + @property def create_ether(self) -> Any: """create_ether function of the fork.""" diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index 4988ef25bcd..afb6a504562 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -357,6 +357,11 @@ def run_state_test(self) -> Any: """ block_env = self.block_environment() block_output = self.fork.BlockOutput() + + # Enable witness mode for Osaka+ forks + if hasattr(self.fork, "enable_witness_mode"): + self.fork.enable_witness_mode(block_env.state) + self.backup_state() if len(self.txs.transactions) > 0: tx = self.txs.transactions[0] @@ -448,6 +453,10 @@ def run_blockchain_test(self) -> None: block_env = self.block_environment() block_output = self.fork.BlockOutput() + # Enable witness mode for Osaka+ forks + if hasattr(self.fork, "enable_witness_mode"): + self.fork.enable_witness_mode(block_env.state) + try: self._run_blockchain_test(block_env, block_output) except InvalidBlock as e: diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index 544838a5dbb..cca014e9636 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -270,6 +270,7 @@ class Result: block_exception: Optional[str] = None block_access_list: Optional[Any] = None block_access_list_hash: Optional[Hash32] = None + execution_witness: Optional[Dict[str, List[str]]] = None def get_receipts_from_output( self, @@ -332,6 +333,17 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: block_output.block_access_list ) ) + # Extract execution witness for Osaka+ forks + if hasattr(t8n.fork, "get_witness"): + witness = t8n.fork.get_witness(block_env.state) + if witness.accessed_nodes: + self.execution_witness = { + "nodes": [ + "0x" + node_rlp.hex() + for node_rlp in witness.accessed_nodes.values() + ] + } + @staticmethod def _block_access_list_to_json(account_changes: Any) -> Any: @@ -398,7 +410,7 @@ def _block_access_list_to_json(account_changes: Any) -> Any: json_account_changes.append(account_data) return json_account_changes - + def json_encode_receipts(self) -> Any: """ Encode receipts to JSON. @@ -476,5 +488,7 @@ def to_json(self) -> Any: data["blockAccessListHash"] = encode_to_hex( self.block_access_list_hash ) + if self.execution_witness is not None: + data["executionWitness"] = self.execution_witness return data From 91f4ccfc71e3ac312defd8916725d77e2b9fbe60 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 17:31:19 -0300 Subject: [PATCH 2/8] track bytecode --- .../execution_testing/fixtures/blockchain.py | 3 +- src/ethereum/forks/osaka/fork.py | 3 ++ src/ethereum/forks/osaka/state.py | 33 +++++++++++++++++++ src/ethereum/forks/osaka/trie.py | 1 + src/ethereum/forks/osaka/utils/message.py | 3 +- src/ethereum/forks/osaka/vm/eoa_delegation.py | 10 +++++- .../osaka/vm/instructions/environment.py | 4 ++- src/ethereum/forks/osaka/vm/interpreter.py | 2 ++ .../evm_tools/t8n/t8n_types.py | 7 ++-- 9 files changed, 60 insertions(+), 6 deletions(-) diff --git a/packages/testing/src/execution_testing/fixtures/blockchain.py b/packages/testing/src/execution_testing/fixtures/blockchain.py index ca92e7888ed..2995d3ef775 100644 --- a/packages/testing/src/execution_testing/fixtures/blockchain.py +++ b/packages/testing/src/execution_testing/fixtures/blockchain.py @@ -449,9 +449,10 @@ def from_fixture_header( class ExecutionWitness(CamelModel): - """Execution witness containing RLP-encoded trie nodes accessed during block execution.""" + """Execution witness containing RLP-encoded trie nodes and bytecodes accessed during block execution.""" nodes: List[str] + bytecodes: List[str] = [] class FixtureEngineNewPayload(CamelModel): diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index ea26282c47f..9fe49d0659f 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -62,6 +62,7 @@ modify_state, set_account_balance, state_root, + track_bytecode_access, ) from .transactions import ( AccessListTransaction, @@ -679,6 +680,7 @@ def process_checked_system_transaction( """ system_contract_code = get_account(block_env.state, target_address).code + track_bytecode_access(block_env.state, system_contract_code) if len(system_contract_code) == 0: raise InvalidBlock( @@ -727,6 +729,7 @@ def process_unchecked_system_transaction( """ system_contract_code = get_account(block_env.state, target_address).code + track_bytecode_access(block_env.state, system_contract_code) return process_system_transaction( block_env, target_address, diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index d2cefe60c07..116b342a308 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -23,6 +23,8 @@ from ethereum_types.frozen import modify from ethereum_types.numeric import U256, Uint +from ethereum.crypto.hash import keccak256 + from .fork_types import EMPTY_ACCOUNT, Account, Address, Root from .trie import ( EMPTY_TRIE_ROOT, @@ -64,6 +66,9 @@ class WitnessState: accessed_accounts: Set[Address] = field(default_factory=set) accessed_storage: Dict[Address, Set[Bytes32]] = field(default_factory=dict) + # Bytecode tracking (code_hash -> bytecode) + accessed_bytecodes: Dict[Bytes32, Bytes] = field(default_factory=dict) + @dataclass class State: @@ -285,6 +290,33 @@ def destroy_account(state: State, address: Address) -> None: set_account(state, address, None) +def track_bytecode_access(state: State, code: Bytes) -> None: + """ + Track bytecode access for execution witness generation. + + Should be called when bytecode is accessed for execution purposes + (CALL variants, EXTCODESIZE, EXTCODECOPY, system contracts). + + Parameters + ---------- + state : State + The state with optional witness tracking. + code : Bytes + The bytecode being accessed. + """ + if state._witness_state is None: + return + + # Skip empty bytecode (EOAs) + if len(code) == 0: + return + + # Compute hash and store for deduplication + code_hash = Bytes32(keccak256(code)) + if code_hash not in state._witness_state.accessed_bytecodes: + state._witness_state.accessed_bytecodes[code_hash] = code + + def destroy_storage(state: State, address: Address) -> None: """ Completely remove the storage at `address`. @@ -923,6 +955,7 @@ def make_storage_root_getter( witness = Witness( accessed_nodes=dict(main_mpt.witness.accessed_nodes), accessed_keys=set(main_mpt.witness.accessed_keys), + bytecodes=sorted(ws.accessed_bytecodes.values()), ) for mpt in storage_mpts.values(): witness.accessed_nodes.update(mpt.witness.accessed_nodes) diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index b2da1381818..8aa5bfe2aff 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -181,6 +181,7 @@ class Witness: default_factory=dict ) # hash -> RLP encoding accessed_keys: Set[Bytes] = field(default_factory=set) # Original keys + bytecodes: List[Bytes] = field(default_factory=list) # Accessed bytecodes @dataclass diff --git a/src/ethereum/forks/osaka/utils/message.py b/src/ethereum/forks/osaka/utils/message.py index ecdc3143234..efd79d8eef9 100644 --- a/src/ethereum/forks/osaka/utils/message.py +++ b/src/ethereum/forks/osaka/utils/message.py @@ -16,7 +16,7 @@ from ethereum_types.numeric import Uint from ..fork_types import Address -from ..state import get_account +from ..state import get_account, track_bytecode_access from ..transactions import Transaction from ..vm import BlockEnvironment, Message, TransactionEnvironment from ..vm.precompiled_contracts.mapping import PRE_COMPILED_CONTRACTS @@ -63,6 +63,7 @@ def prepare_message( current_target = tx.to msg_data = tx.data code = get_account(block_env.state, tx.to).code + track_bytecode_access(block_env.state, code) code_address = tx.to else: raise AssertionError("Target must be address or empty bytes") diff --git a/src/ethereum/forks/osaka/vm/eoa_delegation.py b/src/ethereum/forks/osaka/vm/eoa_delegation.py index e6dd0f12011..8ad754cbe06 100644 --- a/src/ethereum/forks/osaka/vm/eoa_delegation.py +++ b/src/ethereum/forks/osaka/vm/eoa_delegation.py @@ -13,7 +13,13 @@ from ethereum.exceptions import InvalidBlock, InvalidSignatureError from ..fork_types import Address, Authorization -from ..state import account_exists, get_account, increment_nonce, set_code +from ..state import ( + account_exists, + get_account, + increment_nonce, + set_code, + track_bytecode_access, +) from ..utils.hexadecimal import hex_to_address from ..vm.gas import GAS_COLD_ACCOUNT_ACCESS, GAS_WARM_ACCESS from . import Evm, Message @@ -136,6 +142,7 @@ def access_delegation( state = evm.message.block_env.state code = get_account(state, address).code + track_bytecode_access(state, code) if not is_valid_delegation(code): return False, address, code, Uint(0) @@ -146,6 +153,7 @@ def access_delegation( evm.accessed_addresses.add(address) access_gas_cost = GAS_COLD_ACCOUNT_ACCESS code = get_account(state, address).code + track_bytecode_access(state, code) return True, address, code, access_gas_cost diff --git a/src/ethereum/forks/osaka/vm/instructions/environment.py b/src/ethereum/forks/osaka/vm/instructions/environment.py index 28c595ee514..ba861ed5e1b 100644 --- a/src/ethereum/forks/osaka/vm/instructions/environment.py +++ b/src/ethereum/forks/osaka/vm/instructions/environment.py @@ -18,7 +18,7 @@ from ethereum.utils.numeric import ceil32 from ...fork_types import EMPTY_ACCOUNT -from ...state import get_account +from ...state import get_account, track_bytecode_access from ...utils.address import to_address_masked from ...vm.memory import buffer_read, memory_write from .. import Evm @@ -351,6 +351,7 @@ def extcodesize(evm: Evm) -> None: # OPERATION code = get_account(evm.message.block_env.state, address).code + track_bytecode_access(evm.message.block_env.state, code) codesize = U256(len(code)) push(evm.stack, codesize) @@ -393,6 +394,7 @@ def extcodecopy(evm: Evm) -> None: # OPERATION evm.memory += b"\x00" * extend_memory.expand_by code = get_account(evm.message.block_env.state, address).code + track_bytecode_access(evm.message.block_env.state, code) value = buffer_read(code, code_start_index, size) memory_write(evm.memory, memory_start_index, value) diff --git a/src/ethereum/forks/osaka/vm/interpreter.py b/src/ethereum/forks/osaka/vm/interpreter.py index 2d5eb64aff9..43ac1521c45 100644 --- a/src/ethereum/forks/osaka/vm/interpreter.py +++ b/src/ethereum/forks/osaka/vm/interpreter.py @@ -43,6 +43,7 @@ move_ether, rollback_transaction, set_code, + track_bytecode_access, ) from ..vm import Message from ..vm.eoa_delegation import get_delegated_code_address, set_delegation @@ -131,6 +132,7 @@ def process_message_call(message: Message) -> MessageCallOutput: message.disable_precompiles = True message.accessed_addresses.add(delegated_address) message.code = get_account(block_env.state, delegated_address).code + track_bytecode_access(block_env.state, message.code) message.code_address = delegated_address evm = process_message(message) diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index cca014e9636..a86f067d798 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -336,12 +336,15 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: # Extract execution witness for Osaka+ forks if hasattr(t8n.fork, "get_witness"): witness = t8n.fork.get_witness(block_env.state) - if witness.accessed_nodes: + if witness.accessed_nodes or witness.bytecodes: self.execution_witness = { "nodes": [ "0x" + node_rlp.hex() for node_rlp in witness.accessed_nodes.values() - ] + ], + "bytecodes": [ + "0x" + bytecode.hex() for bytecode in witness.bytecodes + ], } From 208cfc481a1ba8e547d482bfa98d6966c7ba5cc8 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 17:34:52 -0300 Subject: [PATCH 3/8] cleanup --- src/ethereum/forks/osaka/fork.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index 9fe49d0659f..9112f5de65c 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -11,7 +11,6 @@ Entry point for the Ethereum specification. """ -import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -768,13 +767,6 @@ def apply_body( The block output for the current block. """ - # Auto-enable witness mode if WITNESS_MODE env var is set - # This allows validating IncrementalMPT against patricialize - if os.environ.get("WITNESS_MODE") and not is_witness_mode_enabled( - block_env.state - ): - enable_witness_mode(block_env.state) - block_output = vm.BlockOutput() process_unchecked_system_transaction( From 2331b719c3229201325bbeb9d7a9014ac6ff2229 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 18:53:45 -0300 Subject: [PATCH 4/8] add ancestors in witness --- .../execution_testing/fixtures/blockchain.py | 1 + .../src/execution_testing/specs/blockchain.py | 4 ++ .../test_types/block_types.py | 1 + src/ethereum/forks/osaka/fork.py | 42 ++++++++++++ src/ethereum/forks/osaka/state.py | 68 +++++++++++++++++++ src/ethereum/forks/osaka/trie.py | 1 + .../forks/osaka/vm/instructions/block.py | 3 + .../evm_tools/loaders/fork_loader.py | 10 +++ .../evm_tools/t8n/__init__.py | 17 +++++ src/ethereum_spec_tools/evm_tools/t8n/env.py | 29 ++++++-- .../evm_tools/t8n/t8n_types.py | 23 ++++--- 11 files changed, 184 insertions(+), 15 deletions(-) diff --git a/packages/testing/src/execution_testing/fixtures/blockchain.py b/packages/testing/src/execution_testing/fixtures/blockchain.py index 2995d3ef775..5b8ca852456 100644 --- a/packages/testing/src/execution_testing/fixtures/blockchain.py +++ b/packages/testing/src/execution_testing/fixtures/blockchain.py @@ -453,6 +453,7 @@ class ExecutionWitness(CamelModel): nodes: List[str] bytecodes: List[str] = [] + ancestors: List[str] = [] class FixtureEngineNewPayload(CamelModel): diff --git a/packages/testing/src/execution_testing/specs/blockchain.py b/packages/testing/src/execution_testing/specs/blockchain.py index e876edc53e2..7d4083117e5 100644 --- a/packages/testing/src/execution_testing/specs/blockchain.py +++ b/packages/testing/src/execution_testing/specs/blockchain.py @@ -96,6 +96,7 @@ def environment_from_parent_header(parent: "FixtureHeader") -> "Environment": parent_gas_limit=parent.gas_limit, parent_ommers_hash=parent.ommers_hash, block_hashes={parent.number: parent.block_hash}, + block_headers={parent.number: parent.rlp}, ) @@ -115,6 +116,9 @@ def apply_new_parent( block_hashes = env.block_hashes.copy() block_hashes[new_parent.number] = new_parent.block_hash updated["block_hashes"] = block_hashes + block_headers = env.block_headers.copy() + block_headers[new_parent.number] = new_parent.rlp + updated["block_headers"] = block_headers return env.copy(**updated) diff --git a/packages/testing/src/execution_testing/test_types/block_types.py b/packages/testing/src/execution_testing/test_types/block_types.py index 2c73a558b98..3f5077a8189 100644 --- a/packages/testing/src/execution_testing/test_types/block_types.py +++ b/packages/testing/src/execution_testing/test_types/block_types.py @@ -137,6 +137,7 @@ def strip_computed_fields(cls, data: Any) -> Any: parent_beacon_block_root: Hash | None = Field(None) block_hashes: Dict[ZeroPaddedHexNumber, Hash] = Field(default_factory=dict) + block_headers: Dict[ZeroPaddedHexNumber, Bytes] = Field(default_factory=dict) ommers: List[Hash] = Field(default_factory=list) withdrawals: List[Withdrawal] | None = Field(None) extra_data: Bytes = Field(Bytes(b"\x00"), exclude=True) diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index 9112f5de65c..eb6725f4719 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -56,11 +56,14 @@ destroy_account, enable_witness_mode, get_account, + get_witness, increment_nonce, is_witness_mode_enabled, modify_state, set_account_balance, + set_witness_metadata, state_root, + track_block_hash_access, track_bytecode_access, ) from .transactions import ( @@ -194,6 +197,35 @@ def get_last_256_block_hashes(chain: BlockChain) -> List[Hash32]: return recent_block_hashes +def get_last_256_block_headers(chain: BlockChain) -> List[Bytes]: + """ + Obtain the list of RLP-encoded headers of the previous 256 blocks. + + This function will return less headers for the first 256 blocks. + The headers are parallel to the hashes from get_last_256_block_hashes. + + Parameters + ---------- + chain : + History and current state. + + Returns + ------- + recent_block_headers : `List[Bytes]` + RLP-encoded headers of recent 256 blocks in order of increasing number. + """ + recent_blocks = chain.blocks[-256:] + if len(recent_blocks) == 0: + return [] + + recent_block_headers: List[Bytes] = [] + for block in recent_blocks: + header_rlp = rlp.encode(block.header) + recent_block_headers.append(header_rlp) + + return recent_block_headers + + def state_transition(chain: BlockChain, block: Block) -> None: """ Attempts to apply a block to an existing block chain. @@ -238,6 +270,13 @@ def state_transition(chain: BlockChain, block: Block) -> None: parent_beacon_block_root=block.header.parent_beacon_block_root, ) + # Set witness metadata if tracking is enabled + set_witness_metadata( + block_env.state, + block.header.number, + get_last_256_block_headers(chain), + ) + block_output = apply_body( block_env=block_env, transactions=block.transactions, @@ -775,6 +814,9 @@ def apply_body( data=block_env.parent_beacon_block_root, ) + # Track parent block access for witness (EIP-2935 system call) + track_block_hash_access(block_env.state, block_env.number - Uint(1)) + process_unchecked_system_transaction( block_env=block_env, target_address=HISTORY_STORAGE_ADDRESS, diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index 116b342a308..66d5c178f2c 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -69,6 +69,14 @@ class WitnessState: # Bytecode tracking (code_hash -> bytecode) accessed_bytecodes: Dict[Bytes32, Bytes] = field(default_factory=dict) + # Ancestor tracking - oldest block accessed via BLOCKHASH + # All headers from this block to parent are needed for chain validation + oldest_accessed_block: Optional[Uint] = None + + # Block metadata for ancestor collection + current_block_number: Uint = field(default_factory=lambda: Uint(0)) + block_headers: List[Bytes] = field(default_factory=list) + @dataclass class State: @@ -317,6 +325,51 @@ def track_bytecode_access(state: State, code: Bytes) -> None: state._witness_state.accessed_bytecodes[code_hash] = code +def track_block_hash_access(state: State, block_number: Uint) -> None: + """ + Track a block hash access for execution witness generation. + + Called when BLOCKHASH opcode or system contracts access a block hash. + Tracks the oldest block accessed since all headers from that block + to the parent are needed for chain validation. + + Parameters + ---------- + state : State + The state with optional witness tracking. + block_number : Uint + The block number being accessed. + """ + if state._witness_state is None: + return + + ws = state._witness_state + if ws.oldest_accessed_block is None or block_number < ws.oldest_accessed_block: + ws.oldest_accessed_block = block_number + + +def set_witness_metadata( + state: State, current_block_number: Uint, block_headers: List[Bytes] +) -> None: + """ + Set block metadata needed for ancestor collection in witness generation. + + Parameters + ---------- + state : State + The state with witness tracking enabled. + current_block_number : Uint + The current block number being executed. + block_headers : List[Bytes] + RLP-encoded headers of previous blocks (up to 256). + """ + if state._witness_state is None: + return + + state._witness_state.current_block_number = current_block_number + state._witness_state.block_headers = block_headers + + def destroy_storage(state: State, address: Address) -> None: """ Completely remove the storage at `address`. @@ -951,11 +1004,26 @@ def make_storage_root_getter( get_storage_root=make_storage_root_getter(addr_storage_root), ) + # Collect ancestors from oldest accessed block to parent (inclusive) + # All headers in this range needed for parent hash chain validation + ancestors: List[Bytes] = [] + if ws.oldest_accessed_block is not None and ws.block_headers: + # Include all headers from oldest accessed to parent (block_number - 1) + for block_num in range( + int(ws.oldest_accessed_block), int(ws.current_block_number) + ): + offset = int(ws.current_block_number) - block_num + if offset <= len(ws.block_headers): + header_rlp = ws.block_headers[-offset] + if header_rlp: + ancestors.append(header_rlp) + # Collect witness from all MPTs witness = Witness( accessed_nodes=dict(main_mpt.witness.accessed_nodes), accessed_keys=set(main_mpt.witness.accessed_keys), bytecodes=sorted(ws.accessed_bytecodes.values()), + ancestors=ancestors, ) for mpt in storage_mpts.values(): witness.accessed_nodes.update(mpt.witness.accessed_nodes) diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index 8aa5bfe2aff..878faee4538 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -182,6 +182,7 @@ class Witness: ) # hash -> RLP encoding accessed_keys: Set[Bytes] = field(default_factory=set) # Original keys bytecodes: List[Bytes] = field(default_factory=list) # Accessed bytecodes + ancestors: List[Bytes] = field(default_factory=list) # RLP-encoded headers @dataclass diff --git a/src/ethereum/forks/osaka/vm/instructions/block.py b/src/ethereum/forks/osaka/vm/instructions/block.py index 43be9e58e23..c6e2f8acfbb 100644 --- a/src/ethereum/forks/osaka/vm/instructions/block.py +++ b/src/ethereum/forks/osaka/vm/instructions/block.py @@ -16,6 +16,7 @@ from .. import Evm from ..gas import GAS_BASE, GAS_BLOCK_HASH, charge_gas from ..stack import pop, push +from ...state import track_block_hash_access def block_hash(evm: Evm) -> None: @@ -57,6 +58,8 @@ def block_hash(evm: Evm) -> None: current_block_hash = evm.message.block_env.block_hashes[ -(current_block_number - block_number) ] + # Track access for witness generation + track_block_hash_access(evm.message.block_env.state, block_number) push(evm.stack, U256.from_be_bytes(current_block_hash)) diff --git a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py index 8809a46688e..7a84c5a859f 100644 --- a/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py +++ b/src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py @@ -313,6 +313,16 @@ def get_witness(self) -> Any: """get_witness function of the fork (Osaka+ only).""" return self._module("state").get_witness + @property + def set_witness_metadata(self) -> Any: + """set_witness_metadata function of the fork (Osaka+ only).""" + return self._module("state").set_witness_metadata + + @property + def track_block_hash_access(self) -> Any: + """track_block_hash_access function of the fork (Osaka+ only).""" + return self._module("state").track_block_hash_access + @property def create_ether(self) -> Any: """create_ether function of the fork.""" diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index afb6a504562..c11b1013ca5 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -361,6 +361,12 @@ def run_state_test(self) -> Any: # Enable witness mode for Osaka+ forks if hasattr(self.fork, "enable_witness_mode"): self.fork.enable_witness_mode(block_env.state) + if hasattr(self.fork, "set_witness_metadata"): + self.fork.set_witness_metadata( + block_env.state, + self.env.block_number, + self.env.block_headers, + ) self.backup_state() if len(self.txs.transactions) > 0: @@ -382,6 +388,11 @@ def run_state_test(self) -> Any: def _run_blockchain_test(self, block_env: Any, block_output: Any) -> None: if self.fork.has_compute_requests_hash: + # Track parent block access for witness (EIP-2935 system call) + if hasattr(self.fork, "track_block_hash_access"): + self.fork.track_block_hash_access( + block_env.state, block_env.number - Uint(1) + ) self.fork.process_unchecked_system_transaction( block_env=block_env, target_address=self.fork.HISTORY_STORAGE_ADDRESS, @@ -456,6 +467,12 @@ def run_blockchain_test(self) -> None: # Enable witness mode for Osaka+ forks if hasattr(self.fork, "enable_witness_mode"): self.fork.enable_witness_mode(block_env.state) + if hasattr(self.fork, "set_witness_metadata"): + self.fork.set_witness_metadata( + block_env.state, + self.env.block_number, + self.env.block_headers, + ) try: self._run_blockchain_test(block_env, block_output) diff --git a/src/ethereum_spec_tools/evm_tools/t8n/env.py b/src/ethereum_spec_tools/evm_tools/t8n/env.py index be719ba7af5..5a1adb8dc4f 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/env.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/env.py @@ -47,6 +47,7 @@ class Env: parent_gas_limit: Optional[Uint] parent_base_fee_per_gas: Optional[Uint] block_hashes: Optional[List[Any]] + block_headers: List[bytes] parent_ommers_hash: Optional[Hash32] ommers: Any parent_beacon_block_root: Optional[Hash32] @@ -277,19 +278,31 @@ def read_block_difficulty(self, data: Any, t8n: "T8N") -> None: def read_block_hashes(self, data: Any) -> None: """ - Read the block hashes. Returns a maximum of 256 block hashes. + Read block hashes and headers. Supports both blockHashes and + blockHeaders inputs. If blockHeaders is provided, hashes are + computed from the RLP-encoded headers. """ - # Read the block hashes block_hashes: List[Any] = [] + block_headers: List[bytes] = [] - # The hex key strings provided might not have standard formatting + # Check if blockHeaders is provided (preferred) + clean_block_headers: Dict[int, bytes] = {} clean_block_hashes: Dict[int, Hash32] = {} - if "blockHashes" in data: + + if "blockHeaders" in data: + # Read headers and compute hashes from them + for key, value in data["blockHeaders"].items(): + int_key = int(key, 16) + header_rlp = hex_to_bytes(value) + clean_block_headers[int_key] = header_rlp + clean_block_hashes[int_key] = Hash32(keccak256(header_rlp)) + elif "blockHashes" in data: + # Fall back to blockHashes if no headers provided for key, value in data["blockHashes"].items(): int_key = int(key, 16) clean_block_hashes[int_key] = Hash32(hex_to_bytes(value)) - # Store a maximum of 256 block hashes. + # Store a maximum of 256 block hashes/headers max_blockhash_count = min(Uint(256), self.block_number) for number in range( self.block_number - max_blockhash_count, self.block_number @@ -299,7 +312,13 @@ def read_block_hashes(self, data: Any) -> None: else: block_hashes.append(None) + if number in clean_block_headers.keys(): + block_headers.append(clean_block_headers[number]) + else: + block_headers.append(b"") + self.block_hashes = block_hashes + self.block_headers = block_headers def read_ommers(self, data: Any, t8n: "T8N") -> None: """ diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index a86f067d798..4631f95e7e2 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -336,16 +336,19 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: # Extract execution witness for Osaka+ forks if hasattr(t8n.fork, "get_witness"): witness = t8n.fork.get_witness(block_env.state) - if witness.accessed_nodes or witness.bytecodes: - self.execution_witness = { - "nodes": [ - "0x" + node_rlp.hex() - for node_rlp in witness.accessed_nodes.values() - ], - "bytecodes": [ - "0x" + bytecode.hex() for bytecode in witness.bytecodes - ], - } + self.execution_witness = { + "nodes": [ + "0x" + node_rlp.hex() + for node_rlp in witness.accessed_nodes.values() + ], + "bytecodes": [ + "0x" + bytecode.hex() for bytecode in witness.bytecodes + ], + "ancestors": [ + "0x" + header_rlp.hex() + for header_rlp in witness.ancestors + ], + } @staticmethod From 59b465f4e5bc3c8a33783f4ac01dac1a60d1a0e4 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 19:14:43 -0300 Subject: [PATCH 5/8] improvements and lints --- .../test_types/block_types.py | 4 +- src/ethereum/forks/osaka/state.py | 141 ++++++++++++------ src/ethereum/forks/osaka/trie.py | 17 +-- .../evm_tools/t8n/__init__.py | 31 ++-- .../evm_tools/t8n/t8n_types.py | 3 +- whitelist.txt | 1 + 6 files changed, 119 insertions(+), 78 deletions(-) diff --git a/packages/testing/src/execution_testing/test_types/block_types.py b/packages/testing/src/execution_testing/test_types/block_types.py index 3f5077a8189..a7ad164f653 100644 --- a/packages/testing/src/execution_testing/test_types/block_types.py +++ b/packages/testing/src/execution_testing/test_types/block_types.py @@ -137,7 +137,9 @@ def strip_computed_fields(cls, data: Any) -> Any: parent_beacon_block_root: Hash | None = Field(None) block_hashes: Dict[ZeroPaddedHexNumber, Hash] = Field(default_factory=dict) - block_headers: Dict[ZeroPaddedHexNumber, Bytes] = Field(default_factory=dict) + block_headers: Dict[ZeroPaddedHexNumber, Bytes] = Field( + default_factory=dict + ) ommers: List[Hash] = Field(default_factory=list) withdrawals: List[Withdrawal] | None = Field(None) extra_data: Bytes = Field(Bytes(b"\x00"), exclude=True) diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index 66d5c178f2c..c37e5278c51 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -77,6 +77,10 @@ class WitnessState: current_block_number: Uint = field(default_factory=lambda: Uint(0)) block_headers: List[Bytes] = field(default_factory=list) + # Pre-execution MPTs + _main_mpt: Optional["IncrementalMPT"] = None + _storage_mpts: Optional[Dict[Address, "IncrementalMPT"]] = None + @dataclass class State: @@ -311,12 +315,10 @@ def track_bytecode_access(state: State, code: Bytes) -> None: The state with optional witness tracking. code : Bytes The bytecode being accessed. - """ - if state._witness_state is None: - return - # Skip empty bytecode (EOAs) - if len(code) == 0: + """ + # Skip if witness mode disabled or empty bytecode (EOAs) + if state._witness_state is None or len(code) == 0: return # Compute hash and store for deduplication @@ -339,12 +341,17 @@ def track_block_hash_access(state: State, block_number: Uint) -> None: The state with optional witness tracking. block_number : Uint The block number being accessed. + """ if state._witness_state is None: return ws = state._witness_state - if ws.oldest_accessed_block is None or block_number < ws.oldest_accessed_block: + is_oldest = ( + ws.oldest_accessed_block is None + or block_number < ws.oldest_accessed_block + ) + if is_oldest: ws.oldest_accessed_block = block_number @@ -362,6 +369,7 @@ def set_witness_metadata( The current block number being executed. block_headers : List[Bytes] RLP-encoded headers of previous blocks (up to 256). + """ if state._witness_state is None: return @@ -386,10 +394,8 @@ def destroy_storage(state: State, address: Address) -> None: if state._witness_state is not None: ws = state._witness_state if address in ws.pre_block_storage_tries_data: - if address not in ws.dirty_storage: - ws.dirty_storage[address] = set() # Mark all pre-block storage keys as dirty (they're now deleted) - ws.dirty_storage[address].update( + ws.dirty_storage.setdefault(address, set()).update( ws.pre_block_storage_tries_data[address].keys() ) @@ -442,9 +448,7 @@ def get_storage(state: State, address: Address, key: Bytes32) -> U256: # Track accessed storage for execution witness generation if state._witness_state is not None: ws = state._witness_state - if address not in ws.accessed_storage: - ws.accessed_storage[address] = set() - ws.accessed_storage[address].add(key) + ws.accessed_storage.setdefault(address, set()).add(key) trie = state._storage_tries.get(address) if trie is None: @@ -485,10 +489,7 @@ def set_storage( # Track dirty storage for deferred witness generation if state._witness_state is not None: - ws = state._witness_state - if address not in ws.dirty_storage: - ws.dirty_storage[address] = set() - ws.dirty_storage[address].add(key) + state._witness_state.dirty_storage.setdefault(address, set()).add(key) def storage_root(state: State, address: Address) -> Root: @@ -536,14 +537,16 @@ def get_storage_root(address: Address) -> Root: return storage_root(state, address) # Calculate root using patricialize (existing implementation) - patricialize_root = root(state._main_trie, get_storage_root=get_storage_root) + patricialize_root = root( + state._main_trie, get_storage_root=get_storage_root + ) # If witness mode is enabled, verify IncrementalMPT produces same root if state._witness_state is not None: - incremental_root, _ = generate_witness(state) - assert patricialize_root == incremental_root, ( + inc_root = incremental_state_root(state) + assert patricialize_root == inc_root, ( f"Root mismatch! patricialize={patricialize_root.hex()} " - f"incremental={incremental_root.hex()}" + f"incremental={inc_root.hex()}" ) return patricialize_root @@ -856,7 +859,7 @@ def enable_witness_mode(state: State) -> None: The state to enable witness mode on. """ - assert not state._snapshots, "Cannot enable witness mode during transaction" + assert not state._snapshots, "Cannot enable witness during transaction" state._witness_state = WitnessState( pre_block_main_trie_data=dict(state._main_trie._data), @@ -885,34 +888,27 @@ def is_witness_mode_enabled(state: State) -> bool: return state._witness_state is not None -def generate_witness(state: State) -> Tuple[Root, Witness]: +def _build_witness_mpts(state: State) -> None: """ - Build MPT from pre-block state, generate execution witness, return root. + Build and cache the IncrementalMPTs for witness generation. - This is called after all block execution completes. It builds a fresh - IncrementalMPT from the pre-block state, traverses read paths to record - pre-state nodes, then applies the final diff for writes. This produces - an execution witness containing nodes needed for: - - Verifying pre-state values that were read - - Re-executing the block - - Computing the post-state root + This builds the MPTs from pre-block state, applies all reads and writes, + which records witness nodes as a side effect. The MPTs are cached in + WitnessState for reuse by root computation and witness extraction. Parameters ---------- state : The state with witness tracking enabled. - Returns - ------- - root : `Root` - The state root computed via IncrementalMPT. - witness : `Witness` - The execution witness containing accessed nodes. - """ assert state._witness_state is not None ws = state._witness_state + # Already built + if ws._main_mpt is not None: + return + # Build fresh MPT from pre-block state pre_main_trie: Trie[Address, Optional[Account]] = Trie( secured=True, default=None @@ -991,19 +987,73 @@ def get_pre_storage_root(address: Address) -> Root: else: addr_storage_root = EMPTY_TRIE_ROOT - # Use a closure that captures the specific storage root for this address - def make_storage_root_getter( - sr: Root, - ) -> Callable[[Address], Root]: - return lambda _: sr - mpt_set( main_mpt, address, account, - get_storage_root=make_storage_root_getter(addr_storage_root), + get_storage_root=lambda _, sr=addr_storage_root: sr, ) + # Cache the built MPTs + ws._main_mpt = main_mpt + ws._storage_mpts = storage_mpts + + +def incremental_state_root(state: State) -> Root: + """ + Compute state root using IncrementalMPT. + + Builds the MPTs if not already built, then returns the root. + The MPT nodes cache their hashes, so subsequent calls are fast. + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + root : `Root` + The state root computed via IncrementalMPT. + + """ + assert state._witness_state is not None + _build_witness_mpts(state) + return mpt_root(state._witness_state._main_mpt) + + +def generate_witness(state: State) -> Tuple[Root, Witness]: + """ + Generate execution witness from the cached MPTs. + + Builds the MPTs if not already built, then extracts the witness + data from them. The witness contains nodes needed for: + - Verifying pre-state values that were read + - Re-executing the block + - Computing the post-state root + + Parameters + ---------- + state : + The state with witness tracking enabled. + + Returns + ------- + root : `Root` + The state root computed via IncrementalMPT. + witness : `Witness` + The execution witness containing accessed nodes. + + """ + assert state._witness_state is not None + ws = state._witness_state + + # Ensure MPTs are built + _build_witness_mpts(state) + + main_mpt = ws._main_mpt + storage_mpts = ws._storage_mpts + # Collect ancestors from oldest accessed block to parent (inclusive) # All headers in this range needed for parent hash chain validation ancestors: List[Bytes] = [] @@ -1036,9 +1086,6 @@ def get_witness(state: State) -> Witness: """ Get the collected witness data from the state. - This generates the witness by building an MPT from the pre-block state - and applying only the final diff. - Parameters ---------- state : diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index 878faee4538..78815ebc48b 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -629,9 +629,7 @@ def _build_mutable_tree( return MutableExtensionNode(key_segment=prefix, child=child) # Branch node case - branches: List[MutableMapping[Bytes, Bytes]] = [] - for _ in range(16): - branches.append({}) + branches: List[MutableMapping[Bytes, Bytes]] = [{} for _ in range(16)] value = b"" for key in obj: @@ -726,8 +724,7 @@ def _encode_mutable_node(node: MutableNode) -> Extended: ) elif isinstance(node, MutableBranchNode): children_encoded = [ - _encode_mutable_node_to_extended(child) - for child in node.children + _encode_mutable_node_to_extended(child) for child in node.children ] return children_encoded + [node.value] else: @@ -943,7 +940,7 @@ def _mpt_insert_node( def _insert_into_leaf( - mpt: IncrementalMPT, + _mpt: IncrementalMPT, node: MutableLeafNode, key: Bytes, value: Bytes, @@ -1024,9 +1021,7 @@ def _insert_into_extension( # Extension needs to be split if prefix_len > 0: # Partial match - create new extension for common prefix - new_child = _split_extension( - node, remaining_key, value, prefix_len - ) + new_child = _split_extension(node, remaining_key, value, prefix_len) return MutableExtensionNode( key_segment=segment[:prefix_len], child=new_child ) @@ -1191,7 +1186,9 @@ def _delete_from_branch( return _collapse_branch(mpt, node) -def _collapse_branch(mpt: IncrementalMPT, node: MutableBranchNode) -> MutableNode: +def _collapse_branch( + mpt: IncrementalMPT, node: MutableBranchNode +) -> MutableNode: """Collapse a branch node if it has only one child and no value.""" non_empty = [(i, c) for i, c in enumerate(node.children) if c is not None] diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index c11b1013ca5..3cb55c58b38 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -350,15 +350,8 @@ def pay_block_rewards(self, block_reward: U256, block_env: Any) -> None: block_env.state, ommer.coinbase, ommer_miner_reward ) - def run_state_test(self) -> Any: - """ - Apply a single transaction on pre-state. No system operations - are performed. - """ - block_env = self.block_environment() - block_output = self.fork.BlockOutput() - - # Enable witness mode for Osaka+ forks + def _enable_witness_mode(self, block_env: Any) -> None: + """Enable witness tracking mode if supported by the fork (Osaka+).""" if hasattr(self.fork, "enable_witness_mode"): self.fork.enable_witness_mode(block_env.state) if hasattr(self.fork, "set_witness_metadata"): @@ -368,6 +361,16 @@ def run_state_test(self) -> Any: self.env.block_headers, ) + def run_state_test(self) -> Any: + """ + Apply a single transaction on pre-state. No system operations + are performed. + """ + block_env = self.block_environment() + block_output = self.fork.BlockOutput() + + self._enable_witness_mode(block_env) + self.backup_state() if len(self.txs.transactions) > 0: tx = self.txs.transactions[0] @@ -464,15 +467,7 @@ def run_blockchain_test(self) -> None: block_env = self.block_environment() block_output = self.fork.BlockOutput() - # Enable witness mode for Osaka+ forks - if hasattr(self.fork, "enable_witness_mode"): - self.fork.enable_witness_mode(block_env.state) - if hasattr(self.fork, "set_witness_metadata"): - self.fork.set_witness_metadata( - block_env.state, - self.env.block_number, - self.env.block_headers, - ) + self._enable_witness_mode(block_env) try: self._run_blockchain_test(block_env, block_output) diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index 4631f95e7e2..2d226bdf50c 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -345,8 +345,7 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: "0x" + bytecode.hex() for bytecode in witness.bytecodes ], "ancestors": [ - "0x" + header_rlp.hex() - for header_rlp in witness.ancestors + "0x" + header_rlp.hex() for header_rlp in witness.ancestors ], } diff --git a/whitelist.txt b/whitelist.txt index 2cf28a99287..ca3c50141de 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -1293,6 +1293,7 @@ webSocket wei wfile whitelist +ws wikipedia wordlist Words2 From d96947df2db7dd3f07dbf1184cd11e2b0d9b45f2 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 19:19:22 -0300 Subject: [PATCH 6/8] checks --- src/ethereum/forks/osaka/state.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index c37e5278c51..c92a99a66d5 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -942,10 +942,9 @@ def get_pre_storage_root(address: Address) -> Root: if read_only_keys: if address not in storage_mpts: # Storage was accessed but didn't exist pre-block - empty_trie: Trie[Bytes32, U256] = Trie( - secured=True, default=U256(0) + storage_mpts[address] = build_mpt( + Trie(secured=True, default=U256(0)) ) - storage_mpts[address] = build_mpt(empty_trie) for key in read_only_keys: mpt_get(storage_mpts[address], key) @@ -954,10 +953,9 @@ def get_pre_storage_root(address: Address) -> Root: for address, dirty_keys in ws.dirty_storage.items(): if address not in storage_mpts: # New storage created during block - empty_trie: Trie[Bytes32, U256] = Trie( - secured=True, default=U256(0) + storage_mpts[address] = build_mpt( + Trie(secured=True, default=U256(0)) ) - storage_mpts[address] = build_mpt(empty_trie) storage_trie = state._storage_tries.get(address) for key in dirty_keys: @@ -987,11 +985,16 @@ def get_pre_storage_root(address: Address) -> Root: else: addr_storage_root = EMPTY_TRIE_ROOT + def get_storage_root_fn( + _: Address, sr: Root = addr_storage_root + ) -> Root: + return sr + mpt_set( main_mpt, address, account, - get_storage_root=lambda _, sr=addr_storage_root: sr, + get_storage_root=get_storage_root_fn, ) # Cache the built MPTs @@ -1019,6 +1022,8 @@ def incremental_state_root(state: State) -> Root: """ assert state._witness_state is not None _build_witness_mpts(state) + assert state._witness_state._main_mpt is not None + assert state._witness_state._storage_mpts is not None return mpt_root(state._witness_state._main_mpt) @@ -1050,6 +1055,8 @@ def generate_witness(state: State) -> Tuple[Root, Witness]: # Ensure MPTs are built _build_witness_mpts(state) + assert ws._main_mpt is not None + assert ws._storage_mpts is not None main_mpt = ws._main_mpt storage_mpts = ws._storage_mpts From c3a1fbb89905b6e3dac8acdc67a596b7383c5709 Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Sat, 17 Jan 2026 23:20:52 -0300 Subject: [PATCH 7/8] improvements and refactor --- src/ethereum/forks/osaka/state.py | 55 ++++++++++++++++++------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index c92a99a66d5..018888c3ce9 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -55,8 +55,8 @@ class WitnessState: """ # Pre-block state (preserved at enable_witness_mode time) - pre_block_main_trie_data: Dict[Address, Optional[Account]] - pre_block_storage_tries_data: Dict[Address, Dict[Bytes32, U256]] + pre_state_accounts: Dict[Address, Optional[Account]] + pre_state_storages: Dict[Address, Dict[Bytes32, U256]] # Dirty tracking during execution (writes) dirty_accounts: Set[Address] = field(default_factory=set) @@ -393,10 +393,10 @@ def destroy_storage(state: State, address: Address) -> None: # Track all pre-block storage keys as dirty for witness generation if state._witness_state is not None: ws = state._witness_state - if address in ws.pre_block_storage_tries_data: + if address in ws.pre_state_storages: # Mark all pre-block storage keys as dirty (they're now deleted) ws.dirty_storage.setdefault(address, set()).update( - ws.pre_block_storage_tries_data[address].keys() + ws.pre_state_storages[address].keys() ) if address in state._storage_tries: @@ -862,8 +862,8 @@ def enable_witness_mode(state: State) -> None: assert not state._snapshots, "Cannot enable witness during transaction" state._witness_state = WitnessState( - pre_block_main_trie_data=dict(state._main_trie._data), - pre_block_storage_tries_data={ + pre_state_accounts=dict(state._main_trie._data), + pre_state_storages={ addr: dict(trie._data) for addr, trie in state._storage_tries.items() }, @@ -913,11 +913,11 @@ def _build_witness_mpts(state: State) -> None: pre_main_trie: Trie[Address, Optional[Account]] = Trie( secured=True, default=None ) - pre_main_trie._data = dict(ws.pre_block_main_trie_data) + pre_main_trie._data = dict(ws.pre_state_accounts) # Build pre-block storage MPTs storage_mpts: Dict[Address, IncrementalMPT[Bytes32, U256]] = {} - for address, data in ws.pre_block_storage_tries_data.items(): + for address, data in ws.pre_state_storages.items(): pre_storage_trie: Trie[Bytes32, U256] = Trie( secured=True, default=U256(0) ) @@ -925,10 +925,8 @@ def _build_witness_mpts(state: State) -> None: storage_mpts[address] = build_mpt(pre_storage_trie) def get_pre_storage_root(address: Address) -> Root: - if address in ws.pre_block_storage_tries_data: - pre_trie: Trie[Bytes32, U256] = Trie(secured=True, default=U256(0)) - pre_trie._data = dict(ws.pre_block_storage_tries_data[address]) - return root(pre_trie) + if address in storage_mpts: + return mpt_root(storage_mpts[address]) return EMPTY_TRIE_ROOT main_mpt = build_mpt(pre_main_trie, get_pre_storage_root) @@ -939,15 +937,12 @@ def get_pre_storage_root(address: Address) -> Root: dirty_keys = ws.dirty_storage.get(address, set()) read_only_keys = accessed_keys - dirty_keys - if read_only_keys: - if address not in storage_mpts: - # Storage was accessed but didn't exist pre-block - storage_mpts[address] = build_mpt( - Trie(secured=True, default=U256(0)) - ) + # Skip if no pre-block storage - empty storage proof is in account's storage_root + if not read_only_keys or address not in storage_mpts: + continue - for key in read_only_keys: - mpt_get(storage_mpts[address], key) + for key in read_only_keys: + mpt_get(storage_mpts[address], key) # 2. Apply dirty storage (writes) for address, dirty_keys in ws.dirty_storage.items(): @@ -958,9 +953,16 @@ def get_pre_storage_root(address: Address) -> Root: ) storage_trie = state._storage_tries.get(address) + # First pass: inserts and updates + for key in dirty_keys: + value = trie_get(storage_trie, key) if storage_trie else U256(0) + if value != 0: + mpt_set(storage_mpts[address], key, value) + # Second pass: deletions for key in dirty_keys: value = trie_get(storage_trie, key) if storage_trie else U256(0) - mpt_set(storage_mpts[address], key, value) + if value == 0: + mpt_set(storage_mpts[address], key, value) # Accounts are "dirty" if: # - Account fields changed (nonce/balance/code) - tracked in dirty_accounts @@ -980,10 +982,17 @@ def get_pre_storage_root(address: Address) -> Root: # Get storage root for this account if address in storage_mpts: addr_storage_root = mpt_root(storage_mpts[address]) - elif address in state._storage_tries: - addr_storage_root = root(state._storage_tries[address]) + # Verify invariant: MPT root must match state storage root + # (if storage was fully cleared, it won't be in state._storage_tries) + if address in state._storage_tries: + assert addr_storage_root == root(state._storage_tries[address]) + else: + assert addr_storage_root == EMPTY_TRIE_ROOT else: + # Verify invariant: no storage in state either + assert address not in state._storage_tries addr_storage_root = EMPTY_TRIE_ROOT + def get_storage_root_fn( _: Address, sr: Root = addr_storage_root From 69002db06ca12d3dc617b6f6f7feb9a42b04966f Mon Sep 17 00:00:00 2001 From: jsign Date: Tue, 20 Jan 2026 12:27:58 -0300 Subject: [PATCH 8/8] improvements Signed-off-by: jsign --- src/ethereum/forks/osaka/fork.py | 4 +- src/ethereum/forks/osaka/state.py | 90 +++++++------------ src/ethereum/forks/osaka/trie.py | 69 +++++++++----- .../forks/osaka/vm/instructions/block.py | 2 +- .../evm_tools/t8n/t8n_types.py | 13 +-- 5 files changed, 92 insertions(+), 86 deletions(-) diff --git a/src/ethereum/forks/osaka/fork.py b/src/ethereum/forks/osaka/fork.py index eb6725f4719..7c81484c1f7 100644 --- a/src/ethereum/forks/osaka/fork.py +++ b/src/ethereum/forks/osaka/fork.py @@ -54,11 +54,8 @@ State, TransientStorage, destroy_account, - enable_witness_mode, get_account, - get_witness, increment_nonce, - is_witness_mode_enabled, modify_state, set_account_balance, set_witness_metadata, @@ -213,6 +210,7 @@ def get_last_256_block_headers(chain: BlockChain) -> List[Bytes]: ------- recent_block_headers : `List[Bytes]` RLP-encoded headers of recent 256 blocks in order of increasing number. + """ recent_blocks = chain.blocks[-256:] if len(recent_blocks) == 0: diff --git a/src/ethereum/forks/osaka/state.py b/src/ethereum/forks/osaka/state.py index 018888c3ce9..c9c2f3a3533 100644 --- a/src/ethereum/forks/osaka/state.py +++ b/src/ethereum/forks/osaka/state.py @@ -192,11 +192,6 @@ def rollback_transaction( transient_storage : TransientStorage The transient storage of the transaction. - Note: Dirty tracking for witness generation persists across rollbacks. - This is correct because the witness needs to capture all nodes that - were accessed during execution, regardless of whether transactions - succeeded or failed. - """ state._main_trie, state._storage_tries = state._snapshots.pop() if not state._snapshots: @@ -487,7 +482,7 @@ def set_storage( if trie._data == {}: del state._storage_tries[address] - # Track dirty storage for deferred witness generation + # Track dirty storage for witness generation if state._witness_state is not None: state._witness_state.dirty_storage.setdefault(address, set()).add(key) @@ -841,11 +836,6 @@ def set_transient_storage( del transient_storage._tries[address] -# ============================================================================= -# Witness Generation Functions -# ============================================================================= - - def enable_witness_mode(state: State) -> None: """ Enable witness tracking mode for the state. @@ -909,50 +899,45 @@ def _build_witness_mpts(state: State) -> None: if ws._main_mpt is not None: return - # Build fresh MPT from pre-block state - pre_main_trie: Trie[Address, Optional[Account]] = Trie( - secured=True, default=None - ) - pre_main_trie._data = dict(ws.pre_state_accounts) - # Build pre-block storage MPTs storage_mpts: Dict[Address, IncrementalMPT[Bytes32, U256]] = {} for address, data in ws.pre_state_storages.items(): - pre_storage_trie: Trie[Bytes32, U256] = Trie( - secured=True, default=U256(0) + storage_mpts[address] = build_mpt( + dict(data), secured=True, default=U256(0) ) - pre_storage_trie._data = dict(data) - storage_mpts[address] = build_mpt(pre_storage_trie) def get_pre_storage_root(address: Address) -> Root: if address in storage_mpts: return mpt_root(storage_mpts[address]) return EMPTY_TRIE_ROOT - main_mpt = build_mpt(pre_main_trie, get_pre_storage_root) + main_mpt = build_mpt( + dict(ws.pre_state_accounts), + secured=True, + default=None, + get_storage_root=get_pre_storage_root, + ) - # 1. Traverse read-only storage keys (accessed but not dirty) - # This records pre-state paths for values that were read + # 1. Do read-only storages accesses for address, accessed_keys in ws.accessed_storage.items(): - dirty_keys = ws.dirty_storage.get(address, set()) - read_only_keys = accessed_keys - dirty_keys - - # Skip if no pre-block storage - empty storage proof is in account's storage_root - if not read_only_keys or address not in storage_mpts: + if address not in storage_mpts: continue - for key in read_only_keys: + for key in accessed_keys: mpt_get(storage_mpts[address], key) - # 2. Apply dirty storage (writes) + # 2. Apply dirty storage to storages (writes) for address, dirty_keys in ws.dirty_storage.items(): if address not in storage_mpts: # New storage created during block storage_mpts[address] = build_mpt( - Trie(secured=True, default=U256(0)) + {}, secured=True, default=U256(0) ) storage_trie = state._storage_tries.get(address) + # We do two passes to ensure deletions are processed after + # inserts/updates to minimize the number of nodes touched + # in the MPT. # First pass: inserts and updates for key in dirty_keys: value = trie_get(storage_trie, key) if storage_trie else U256(0) @@ -969,21 +954,21 @@ def get_pre_storage_root(address: Address) -> Root: # - Storage changed (storage root changed) - tracked in dirty_storage all_dirty_accounts = ws.dirty_accounts | set(ws.dirty_storage.keys()) - # 3. Traverse read-only accounts (accessed but not dirty) - # This records pre-state paths for accounts that were read - read_only_accounts = ws.accessed_accounts - all_dirty_accounts - for address in read_only_accounts: + # 3. Traverse accounts that were read + for address in ws.accessed_accounts: mpt_get(main_mpt, address) - # 4. Apply dirty accounts (writes, with current storage roots) + # 4. Apply dirty accounts for address in all_dirty_accounts: + # Get new account data from usual trie account = trie_get(state._main_trie, address) # Get storage root for this account if address in storage_mpts: addr_storage_root = mpt_root(storage_mpts[address]) # Verify invariant: MPT root must match state storage root - # (if storage was fully cleared, it won't be in state._storage_tries) + # (if storage was fully cleared, it won't be in + # state._storage_tries) if address in state._storage_tries: assert addr_storage_root == root(state._storage_tries[address]) else: @@ -992,7 +977,6 @@ def get_pre_storage_root(address: Address) -> Root: # Verify invariant: no storage in state either assert address not in state._storage_tries addr_storage_root = EMPTY_TRIE_ROOT - def get_storage_root_fn( _: Address, sr: Root = addr_storage_root @@ -1038,13 +1022,7 @@ def incremental_state_root(state: State) -> Root: def generate_witness(state: State) -> Tuple[Root, Witness]: """ - Generate execution witness from the cached MPTs. - - Builds the MPTs if not already built, then extracts the witness - data from them. The witness contains nodes needed for: - - Verifying pre-state values that were read - - Re-executing the block - - Computing the post-state root + Generate execution witness. Parameters ---------- @@ -1073,16 +1051,16 @@ def generate_witness(state: State) -> Tuple[Root, Witness]: # Collect ancestors from oldest accessed block to parent (inclusive) # All headers in this range needed for parent hash chain validation ancestors: List[Bytes] = [] - if ws.oldest_accessed_block is not None and ws.block_headers: - # Include all headers from oldest accessed to parent (block_number - 1) - for block_num in range( - int(ws.oldest_accessed_block), int(ws.current_block_number) - ): - offset = int(ws.current_block_number) - block_num - if offset <= len(ws.block_headers): - header_rlp = ws.block_headers[-offset] - if header_rlp: - ancestors.append(header_rlp) + assert ws.oldest_accessed_block is not None and ws.block_headers + # Include all headers from oldest accessed to parent (block_number - 1) + for block_num in range( + int(ws.oldest_accessed_block), int(ws.current_block_number) + ): + offset = int(ws.current_block_number) - block_num + if offset <= len(ws.block_headers): + header_rlp = ws.block_headers[-offset] + if header_rlp: + ancestors.append(header_rlp) # Collect witness from all MPTs witness = Witness( diff --git a/src/ethereum/forks/osaka/trie.py b/src/ethereum/forks/osaka/trie.py index 78815ebc48b..ab7820fb7a5 100644 --- a/src/ethereum/forks/osaka/trie.py +++ b/src/ethereum/forks/osaka/trie.py @@ -420,18 +420,21 @@ def bytes_to_nibble_list(bytes_: Bytes) -> Bytes: return Bytes(nibble_list) -def _prepare_trie( - trie: Trie[K, V], +def _prepare_data( + data: Mapping[K, V], + secured: bool, get_storage_root: Optional[Callable[[Address], Root]] = None, ) -> Mapping[Bytes, Bytes]: """ - Prepares the trie for root calculation. Removes values that are empty, + Prepares data for trie root calculation. Removes values that are empty, hashes the keys (if `secured == True`) and encodes all the nodes. Parameters ---------- - trie : - The `Trie` to prepare. + data : + The key-value data to prepare. + secured : + Whether keys should be hashed. get_storage_root : Function to get the storage root of an account. Needed to encode `Account` objects. @@ -444,7 +447,7 @@ def _prepare_trie( """ mapped: MutableMapping[Bytes, Bytes] = {} - for preimage, value in trie._data.items(): + for preimage, value in data.items(): if isinstance(value, Account): assert get_storage_root is not None address = Address(preimage) @@ -454,7 +457,7 @@ def _prepare_trie( if encoded_value == b"": raise AssertionError key: Bytes - if trie.secured: + if secured: # "secure" tries hash keys once before construction key = keccak256(preimage) else: @@ -464,6 +467,31 @@ def _prepare_trie( return mapped +def _prepare_trie( + trie: Trie[K, V], + get_storage_root: Optional[Callable[[Address], Root]] = None, +) -> Mapping[Bytes, Bytes]: + """ + Prepares the trie for root calculation. Removes values that are empty, + hashes the keys (if `secured == True`) and encodes all the nodes. + + Parameters + ---------- + trie : + The `Trie` to prepare. + get_storage_root : + Function to get the storage root of an account. Needed to encode + `Account` objects. + + Returns + ------- + out : `Mapping[ethereum.base_types.Bytes, Node]` + Object with keys mapped to nibble-byte form. + + """ + return _prepare_data(trie._data, trie.secured, get_storage_root) + + def root( trie: Trie[K, V], get_storage_root: Optional[Callable[[Address], Root]] = None, @@ -574,11 +602,6 @@ def patricialize( ) -# ============================================================================= -# Incremental MPT Functions -# ============================================================================= - - def _build_mutable_tree( obj: Mapping[Bytes, Bytes], level: Uint ) -> MutableNode: @@ -646,19 +669,25 @@ def _build_mutable_tree( def build_mpt( - trie: Trie[K, V], + data: Mapping[K, V], + secured: bool, + default: V, get_storage_root: Optional[Callable[[Address], Root]] = None, ) -> IncrementalMPT[K, V]: """ - Build an IncrementalMPT from an existing Trie. + Build an IncrementalMPT from key-value data. This is called with the pre-execution state to create a mutable tree structure that can be updated in-place during execution. Parameters ---------- - trie : - The source Trie to build from. + data : + The source key-value data to build from. + secured : + Whether to hash keys before insertion. + default : + Default value for missing keys. get_storage_root : Function to get the storage root of an account. @@ -668,14 +697,14 @@ def build_mpt( An incremental MPT with the same data. """ - prepared = _prepare_trie(trie, get_storage_root) + prepared = _prepare_data(data, secured, get_storage_root) root_node = _build_mutable_tree(prepared, Uint(0)) return IncrementalMPT( - secured=trie.secured, - default=trie.default, + secured=secured, + default=default, root_node=root_node, - _data=copy.copy(trie._data), + _data=dict(data), ) diff --git a/src/ethereum/forks/osaka/vm/instructions/block.py b/src/ethereum/forks/osaka/vm/instructions/block.py index c6e2f8acfbb..78f35a0ff48 100644 --- a/src/ethereum/forks/osaka/vm/instructions/block.py +++ b/src/ethereum/forks/osaka/vm/instructions/block.py @@ -13,10 +13,10 @@ from ethereum_types.numeric import U256, Uint +from ...state import track_block_hash_access from .. import Evm from ..gas import GAS_BASE, GAS_BLOCK_HASH, charge_gas from ..stack import pop, push -from ...state import track_block_hash_access def block_hash(evm: Evm) -> None: diff --git a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py index 2d226bdf50c..a4d8bd2c299 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py @@ -337,10 +337,12 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: if hasattr(t8n.fork, "get_witness"): witness = t8n.fork.get_witness(block_env.state) self.execution_witness = { - "nodes": [ - "0x" + node_rlp.hex() - for node_rlp in witness.accessed_nodes.values() - ], + "nodes": sorted( + [ + "0x" + node_rlp.hex() + for node_rlp in witness.accessed_nodes.values() + ] + ), "bytecodes": [ "0x" + bytecode.hex() for bytecode in witness.bytecodes ], @@ -349,7 +351,6 @@ def update(self, t8n: "T8N", block_env: Any, block_output: Any) -> None: ], } - @staticmethod def _block_access_list_to_json(account_changes: Any) -> Any: """ @@ -415,7 +416,7 @@ def _block_access_list_to_json(account_changes: Any) -> Any: json_account_changes.append(account_data) return json_account_changes - + def json_encode_receipts(self) -> Any: """ Encode receipts to JSON.